From: @changzherui Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -142,8 +142,16 @@ int64_t LSTM::get_num_directions() const { | |||||
| auto value_ptr = this->GetAttr(kNumDirections); | auto value_ptr = this->GetAttr(kNumDirections); | ||||
| return GetValue<int64_t>(value_ptr); | return GetValue<int64_t>(value_ptr); | ||||
| } | } | ||||
| void LSTM::set_zoneout_cell(float zoneout_cell) { AddAttr(kZoneoutCell, MakeValue(zoneout_cell)); } | |||||
| float LSTM::get_zoneout_cell() const { return GetValue<float>(this->GetAttr(kZoneoutCell)); } | |||||
| void LSTM::set_zoneout_hidden(float zoneout_hidden) { AddAttr(kZoneoutHidden, MakeValue(zoneout_hidden)); } | |||||
| float LSTM::get_zoneout_hidden() const { return GetValue<float>(this->GetAttr(kZoneoutHidden)); } | |||||
| void LSTM::Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias, | void LSTM::Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias, | ||||
| const float dropout, const bool bidirectional) { | |||||
| const float dropout, const bool bidirectional, const float zoneout_cell, const float zoneout_hidden) { | |||||
| this->set_input_size(input_size); | this->set_input_size(input_size); | ||||
| this->set_hidden_size(hidden_size); | this->set_hidden_size(hidden_size); | ||||
| this->set_num_layers(num_layers); | this->set_num_layers(num_layers); | ||||
| @@ -155,6 +163,8 @@ void LSTM::Init(const int64_t input_size, const int64_t hidden_size, const int64 | |||||
| } else { | } else { | ||||
| this->set_num_directions(1); | this->set_num_directions(1); | ||||
| } | } | ||||
| this->set_zoneout_cell(zoneout_cell); | |||||
| this->set_zoneout_hidden(zoneout_hidden); | |||||
| } | } | ||||
| int64_t LSTM::get_good_ld(const int64_t dim, const int64_t type_size) { | int64_t LSTM::get_good_ld(const int64_t dim, const int64_t type_size) { | ||||
| @@ -37,7 +37,8 @@ class LSTM : public PrimitiveC { | |||||
| ~LSTM() = default; | ~LSTM() = default; | ||||
| MS_DECLARE_PARENT(LSTM, PrimitiveC); | MS_DECLARE_PARENT(LSTM, PrimitiveC); | ||||
| void Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias, | void Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias, | ||||
| const float dropout, const bool bidirectional = false); | |||||
| const float dropout, const bool bidirectional = false, const float zoneout_cell = 0.0f, | |||||
| const float zoneout_hidden = 0.0f); | |||||
| void set_input_size(const int64_t input_size); | void set_input_size(const int64_t input_size); | ||||
| int64_t get_input_size() const; | int64_t get_input_size() const; | ||||
| void set_hidden_size(const int64_t hidden_size); | void set_hidden_size(const int64_t hidden_size); | ||||
| @@ -52,6 +53,10 @@ class LSTM : public PrimitiveC { | |||||
| bool get_bidirectional() const; | bool get_bidirectional() const; | ||||
| void set_num_directions(const int64_t num_directions); | void set_num_directions(const int64_t num_directions); | ||||
| int64_t get_num_directions() const; | int64_t get_num_directions() const; | ||||
| void set_zoneout_cell(float zoneout_cell); | |||||
| float get_zoneout_cell() const; | |||||
| void set_zoneout_hidden(float zoneout_hidden); | |||||
| float get_zoneout_hidden() const; | |||||
| int64_t get_good_ld(const int64_t dim, const int64_t type_size); | int64_t get_good_ld(const int64_t dim, const int64_t type_size); | ||||
| }; | }; | ||||
| AbstractBasePtr LstmInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr LstmInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -227,6 +227,8 @@ constexpr auto kResetAfter = "reset_after"; | |||||
| constexpr auto kCoeff = "coeff"; | constexpr auto kCoeff = "coeff"; | ||||
| constexpr auto kIsDepthWise = "is_depth_wise"; | constexpr auto kIsDepthWise = "is_depth_wise"; | ||||
| constexpr auto kIsDepthWiseNative = "is_depth_wise_native"; | constexpr auto kIsDepthWiseNative = "is_depth_wise_native"; | ||||
| constexpr auto kZoneoutCell = "zoneout_cell"; | |||||
| constexpr auto kZoneoutHidden = "zoneout_hidden"; | |||||
| const std::set<TypeId> common_valid_types = { | const std::set<TypeId> common_valid_types = { | ||||
| kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, kNumberTypeUInt8, kNumberTypeUInt16, | kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, kNumberTypeUInt8, kNumberTypeUInt16, | ||||