|
|
|
@@ -154,8 +154,12 @@ class LSTM(Cell): |
|
|
|
self.concat_2dim = P.Concat(axis=2) |
|
|
|
self.cast = P.Cast() |
|
|
|
self.shape = P.Shape() |
|
|
|
if dropout != 0: |
|
|
|
self.dropout_op = nn.Dropout(float(dropout)) |
|
|
|
if dropout < 0 or dropout > 1: |
|
|
|
raise ValueError("For LSTM, dropout must be a number in range [0, 1], but got {}".format(dropout)) |
|
|
|
if dropout == 1: |
|
|
|
self.dropout_op = P.ZerosLike() |
|
|
|
else: |
|
|
|
self.dropout_op = nn.Dropout(float(1 - dropout)) |
|
|
|
b0 = np.zeros(gate_size, dtype=np.float16) |
|
|
|
self.w_list = [] |
|
|
|
self.b_list = [] |
|
|
|
|