From f1a7274f77836775370b497d0256e99d26e9a4c1 Mon Sep 17 00:00:00 2001 From: changzherui Date: Tue, 2 Mar 2021 12:31:07 +0800 Subject: [PATCH] modify lstm --- mindspore/core/ops/lstm.cc | 12 +++++++++++- mindspore/core/ops/lstm.h | 7 ++++++- mindspore/core/ops/op_utils.h | 2 ++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/mindspore/core/ops/lstm.cc b/mindspore/core/ops/lstm.cc index 79a5e60286..8c51b3cde2 100644 --- a/mindspore/core/ops/lstm.cc +++ b/mindspore/core/ops/lstm.cc @@ -142,8 +142,16 @@ int64_t LSTM::get_num_directions() const { auto value_ptr = this->GetAttr(kNumDirections); return GetValue(value_ptr); } +void LSTM::set_zoneout_cell(float zoneout_cell) { AddAttr(kZoneoutCell, MakeValue(zoneout_cell)); } + +float LSTM::get_zoneout_cell() const { return GetValue(this->GetAttr(kZoneoutCell)); } + +void LSTM::set_zoneout_hidden(float zoneout_hidden) { AddAttr(kZoneoutHidden, MakeValue(zoneout_hidden)); } + +float LSTM::get_zoneout_hidden() const { return GetValue(this->GetAttr(kZoneoutHidden)); } + void LSTM::Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias, - const float dropout, const bool bidirectional) { + const float dropout, const bool bidirectional, const float zoneout_cell, const float zoneout_hidden) { this->set_input_size(input_size); this->set_hidden_size(hidden_size); this->set_num_layers(num_layers); @@ -155,6 +163,8 @@ void LSTM::Init(const int64_t input_size, const int64_t hidden_size, const int64 } else { this->set_num_directions(1); } + this->set_zoneout_cell(zoneout_cell); + this->set_zoneout_hidden(zoneout_hidden); } int64_t LSTM::get_good_ld(const int64_t dim, const int64_t type_size) { diff --git a/mindspore/core/ops/lstm.h b/mindspore/core/ops/lstm.h index d45e3f05ac..4d128e8896 100644 --- a/mindspore/core/ops/lstm.h +++ b/mindspore/core/ops/lstm.h @@ -37,7 +37,8 @@ class LSTM : public PrimitiveC { ~LSTM() = default; MS_DECLARE_PARENT(LSTM, PrimitiveC); void Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias, - const float dropout, const bool bidirectional = false); + const float dropout, const bool bidirectional = false, const float zoneout_cell = 0.0f, + const float zoneout_hidden = 0.0f); void set_input_size(const int64_t input_size); int64_t get_input_size() const; void set_hidden_size(const int64_t hidden_size); @@ -52,6 +53,10 @@ class LSTM : public PrimitiveC { bool get_bidirectional() const; void set_num_directions(const int64_t num_directions); int64_t get_num_directions() const; + void set_zoneout_cell(float zoneout_cell); + float get_zoneout_cell() const; + void set_zoneout_hidden(float zoneout_hidden); + float get_zoneout_hidden() const; int64_t get_good_ld(const int64_t dim, const int64_t type_size); }; AbstractBasePtr LstmInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/core/ops/op_utils.h b/mindspore/core/ops/op_utils.h index 3902e73e35..47d7f7eac8 100644 --- a/mindspore/core/ops/op_utils.h +++ b/mindspore/core/ops/op_utils.h @@ -227,6 +227,8 @@ constexpr auto kResetAfter = "reset_after"; constexpr auto kCoeff = "coeff"; constexpr auto kIsDepthWise = "is_depth_wise"; constexpr auto kIsDepthWiseNative = "is_depth_wise_native"; +constexpr auto kZoneoutCell = "zoneout_cell"; +constexpr auto kZoneoutHidden = "zoneout_hidden"; const std::set common_valid_types = { kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, kNumberTypeUInt8, kNumberTypeUInt16,