Browse Source

!10243 Fix output for Ascend backend of nn.LSTM when dropout is 1.0.

From: @liu_xiao_93
Reviewed-by: @liangchenghui,@wuxuejian
Signed-off-by: @liangchenghui
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
7ea0a14795
1 changed files with 6 additions and 2 deletions
  1. +6
    -2
      mindspore/nn/layer/lstm.py

+ 6
- 2
mindspore/nn/layer/lstm.py View File

@@ -154,8 +154,12 @@ class LSTM(Cell):
self.concat_2dim = P.Concat(axis=2)
self.cast = P.Cast()
self.shape = P.Shape()
if dropout != 0:
self.dropout_op = nn.Dropout(float(dropout))
if dropout < 0 or dropout > 1:
raise ValueError("For LSTM, dropout must be a number in range [0, 1], but got {}".format(dropout))
if dropout == 1:
self.dropout_op = P.ZerosLike()
else:
self.dropout_op = nn.Dropout(float(1 - dropout))
b0 = np.zeros(gate_size, dtype=np.float16)
self.w_list = []
self.b_list = []


Loading…
Cancel
Save