Browse Source

!12778 modify lstm zoneout

From: @changzherui
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
f5766917ea
3 changed files with 19 additions and 2 deletions
  1. +11
    -1
      mindspore/core/ops/lstm.cc
  2. +6
    -1
      mindspore/core/ops/lstm.h
  3. +2
    -0
      mindspore/core/ops/op_utils.h

+ 11
- 1
mindspore/core/ops/lstm.cc View File

@@ -142,8 +142,16 @@ int64_t LSTM::get_num_directions() const {
auto value_ptr = this->GetAttr(kNumDirections);
return GetValue<int64_t>(value_ptr);
}
void LSTM::set_zoneout_cell(float zoneout_cell) { AddAttr(kZoneoutCell, MakeValue(zoneout_cell)); }

float LSTM::get_zoneout_cell() const { return GetValue<float>(this->GetAttr(kZoneoutCell)); }

void LSTM::set_zoneout_hidden(float zoneout_hidden) { AddAttr(kZoneoutHidden, MakeValue(zoneout_hidden)); }

float LSTM::get_zoneout_hidden() const { return GetValue<float>(this->GetAttr(kZoneoutHidden)); }

void LSTM::Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias,
const float dropout, const bool bidirectional) {
const float dropout, const bool bidirectional, const float zoneout_cell, const float zoneout_hidden) {
this->set_input_size(input_size);
this->set_hidden_size(hidden_size);
this->set_num_layers(num_layers);
@@ -155,6 +163,8 @@ void LSTM::Init(const int64_t input_size, const int64_t hidden_size, const int64
} else {
this->set_num_directions(1);
}
this->set_zoneout_cell(zoneout_cell);
this->set_zoneout_hidden(zoneout_hidden);
}

int64_t LSTM::get_good_ld(const int64_t dim, const int64_t type_size) {


+ 6
- 1
mindspore/core/ops/lstm.h View File

@@ -37,7 +37,8 @@ class LSTM : public PrimitiveC {
~LSTM() = default;
MS_DECLARE_PARENT(LSTM, PrimitiveC);
void Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias,
const float dropout, const bool bidirectional = false);
const float dropout, const bool bidirectional = false, const float zoneout_cell = 0.0f,
const float zoneout_hidden = 0.0f);
void set_input_size(const int64_t input_size);
int64_t get_input_size() const;
void set_hidden_size(const int64_t hidden_size);
@@ -52,6 +53,10 @@ class LSTM : public PrimitiveC {
bool get_bidirectional() const;
void set_num_directions(const int64_t num_directions);
int64_t get_num_directions() const;
void set_zoneout_cell(float zoneout_cell);
float get_zoneout_cell() const;
void set_zoneout_hidden(float zoneout_hidden);
float get_zoneout_hidden() const;
int64_t get_good_ld(const int64_t dim, const int64_t type_size);
};
AbstractBasePtr LstmInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,


+ 2
- 0
mindspore/core/ops/op_utils.h View File

@@ -227,6 +227,8 @@ constexpr auto kResetAfter = "reset_after";
constexpr auto kCoeff = "coeff";
constexpr auto kIsDepthWise = "is_depth_wise";
constexpr auto kIsDepthWiseNative = "is_depth_wise_native";
constexpr auto kZoneoutCell = "zoneout_cell";
constexpr auto kZoneoutHidden = "zoneout_hidden";

const std::set<TypeId> common_valid_types = {
kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, kNumberTypeUInt8, kNumberTypeUInt16,


Loading…
Cancel
Save