|
|
|
@@ -142,8 +142,16 @@ int64_t LSTM::get_num_directions() const { |
|
|
|
auto value_ptr = this->GetAttr(kNumDirections); |
|
|
|
return GetValue<int64_t>(value_ptr); |
|
|
|
} |
|
|
|
void LSTM::set_zoneout_cell(float zoneout_cell) { AddAttr(kZoneoutCell, MakeValue(zoneout_cell)); } |
|
|
|
|
|
|
|
float LSTM::get_zoneout_cell() const { return GetValue<float>(this->GetAttr(kZoneoutCell)); } |
|
|
|
|
|
|
|
void LSTM::set_zoneout_hidden(float zoneout_hidden) { AddAttr(kZoneoutHidden, MakeValue(zoneout_hidden)); } |
|
|
|
|
|
|
|
float LSTM::get_zoneout_hidden() const { return GetValue<float>(this->GetAttr(kZoneoutHidden)); } |
|
|
|
|
|
|
|
void LSTM::Init(const int64_t input_size, const int64_t hidden_size, const int64_t num_layers, const bool has_bias, |
|
|
|
const float dropout, const bool bidirectional) { |
|
|
|
const float dropout, const bool bidirectional, const float zoneout_cell, const float zoneout_hidden) { |
|
|
|
this->set_input_size(input_size); |
|
|
|
this->set_hidden_size(hidden_size); |
|
|
|
this->set_num_layers(num_layers); |
|
|
|
@@ -155,6 +163,8 @@ void LSTM::Init(const int64_t input_size, const int64_t hidden_size, const int64 |
|
|
|
} else { |
|
|
|
this->set_num_directions(1); |
|
|
|
} |
|
|
|
this->set_zoneout_cell(zoneout_cell); |
|
|
|
this->set_zoneout_hidden(zoneout_hidden); |
|
|
|
} |
|
|
|
|
|
|
|
int64_t LSTM::get_good_ld(const int64_t dim, const int64_t type_size) { |
|
|
|
|