diff --git a/mindspore/nn/layer/lstm.py b/mindspore/nn/layer/lstm.py
index a3ae9ca67d..b0d5a72b2a 100755
--- a/mindspore/nn/layer/lstm.py
+++ b/mindspore/nn/layer/lstm.py
@@ -15,17 +15,27 @@
 """lstm"""
 import math
 import numpy as np
+import mindspore.context as context
+import mindspore.common.dtype as mstype
+from mindspore.ops.primitive import constexpr
 from mindspore._checkparam import Validator as validator
 from mindspore.common.initializer import initializer
-from mindspore.common.parameter import Parameter
+from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.common.tensor import Tensor
 from mindspore.nn.cell import Cell
+from mindspore import nn
 from mindspore.ops import operations as P
 
 __all__ = ['LSTM', 'LSTMCell']
 
 
+@constexpr
+def _create_sequence_length(shape):
+    """Create a constant sequence_length tensor (all timesteps valid) for ReverseSequence."""
+    num_step, batch_size, _ = shape
+    sequence_length = Tensor(np.ones(batch_size, np.int32) * num_step, mstype.int32)
+    return sequence_length
+
+
 class LSTM(Cell):
     r"""
     LSTM (Long Short-Term Memory) layer.
@@ -105,9 +115,20 @@ class LSTM(Cell):
         validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
         validator.check_positive_int(hidden_size, "hidden_size", self.cls_name)
         validator.check_positive_int(num_layers, "num_layers", self.cls_name)
+        self.is_ascend = context.get_context("device_target") == "Ascend"
 
         self.batch_first = batch_first
         self.transpose = P.Transpose()
+        self.num_layers = num_layers
+        self.bidirectional = bidirectional
+        self.dropout = dropout
+        self.reverse_seq = P.ReverseSequence(batch_dim=1, seq_dim=0)
+        self.concat = P.Concat(axis=0)
+        self.concat_2dim = P.Concat(axis=2)
+        self.cast = P.Cast()
+        self.shape = P.Shape()
+        if dropout != 0:
+            self.dropout_op = nn.Dropout(float(dropout))
         self.lstm = P.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
@@ -117,23 +138,98 @@ class LSTM(Cell):
         weight_size = 0
         gate_size = 4 * hidden_size
+        stdv = 1 / math.sqrt(hidden_size)
         num_directions = 2 if bidirectional else 1
+        b0 = np.zeros(gate_size, dtype=np.float16)
+        self.w_list = []
+        self.b_list = []
+        # DynamicRNN runs the per-layer LSTM computation in float16 on Ascend.
+        self.rnns_fw = P.DynamicRNN(forget_bias=0.0)
+        self.rnns_bw = P.DynamicRNN(forget_bias=0.0)
 
         for layer in range(num_layers):
             input_layer_size = input_size if layer == 0 else hidden_size * num_directions
             increment_size = gate_size * input_layer_size
             increment_size += gate_size * hidden_size
+            # DynamicRNN takes one weight tensor of shape [input + hidden, 4 * hidden].
+            w_shape = input_size if layer == 0 else (num_directions * hidden_size)
+            w_np = np.random.uniform(-stdv, stdv, (w_shape + hidden_size, gate_size)).astype(np.float16)
+            self.w_list.append(Parameter(
+                initializer(Tensor(w_np), [w_shape + hidden_size, gate_size]), name='weight_fw' + str(layer)))
             if has_bias:
                 increment_size += 2 * gate_size
+                b_np = np.random.uniform(-stdv, stdv, gate_size).astype(np.float16)
+                self.b_list.append(Parameter(initializer(Tensor(b_np), [gate_size]), name='bias_fw' + str(layer)))
+            else:
+                self.b_list.append(Parameter(initializer(Tensor(b0), [gate_size]), name='bias_fw' + str(layer)))
             weight_size += increment_size * num_directions
-        stdv = 1 / math.sqrt(hidden_size)
+            if bidirectional:
+                w_bw_np = np.random.uniform(-stdv, stdv, (w_shape + hidden_size, gate_size)).astype(np.float16)
+                self.w_list.append(Parameter(initializer(Tensor(w_bw_np), [w_shape + hidden_size, gate_size]),
+                                             name='weight_bw' + str(layer)))
+                b_bw_np = np.random.uniform(-stdv, stdv, gate_size).astype(np.float16) if has_bias else b0
+                self.b_list.append(Parameter(initializer(Tensor(b_bw_np), [gate_size]), name='bias_bw' + str(layer)))
+        self.w_list = ParameterTuple(self.w_list)
+        self.b_list = ParameterTuple(self.b_list)
         w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
         self.weight = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight')
 
+    def _stacked_bi_dynamic_rnn(self, x, init_h, init_c, weight, bias):
+        """stacked bidirectional dynamic_rnn"""
+        x_shape = self.shape(x)
+        sequence_length = _create_sequence_length(x_shape)
+        pre_layer = x
+        hn = ()
+        cn = ()
+        output = x
+        for i in range(self.num_layers):
+            offset = i * 2
+            weight_fw, weight_bw = weight[offset], weight[offset + 1]
+            bias_fw, bias_bw = bias[offset], bias[offset + 1]
+            init_h_fw, init_h_bw = init_h[offset:offset + 1, :, :], init_h[offset + 1:offset + 2, :, :]
+            init_c_fw, init_c_bw = init_c[offset:offset + 1, :, :], init_c[offset + 1:offset + 2, :, :]
+            # The backward direction reverses the sequence before and after the RNN.
+            bw_x = self.reverse_seq(pre_layer, sequence_length)
+            y, h, c, _, _, _, _, _ = self.rnns_fw(pre_layer, weight_fw, bias_fw, None, init_h_fw, init_c_fw)
+            y_bw, h_bw, c_bw, _, _, _, _, _ = self.rnns_bw(bw_x, weight_bw, bias_bw, None, init_h_bw, init_c_bw)
+            y_bw = self.reverse_seq(y_bw, sequence_length)
+            output = self.concat_2dim((y, y_bw))
+            pre_layer = self.dropout_op(output) if self.dropout else output
+            hn += (h[-1:, :, :],)
+            hn += (h_bw[-1:, :, :],)
+            cn += (c[-1:, :, :],)
+            cn += (c_bw[-1:, :, :],)
+        status_h = self.concat(hn)
+        status_c = self.concat(cn)
+        return output, status_h, status_c
+
+    def _stacked_dynamic_rnn(self, x, init_h, init_c, weight, bias):
+        """stacked multi-layer dynamic_rnn"""
+        pre_layer = x
+        hn = ()
+        cn = ()
+        y = 0
+        for i in range(self.num_layers):
+            weight_fw, bias_fw = weight[i], bias[i]
+            init_h_fw, init_c_fw = init_h[i:i + 1, :, :], init_c[i:i + 1, :, :]
+            y, h, c, _, _, _, _, _ = self.rnns_fw(pre_layer, weight_fw, bias_fw, None, init_h_fw, init_c_fw)
+            pre_layer = self.dropout_op(y) if self.dropout else y
+            hn += (h[-1:, :, :],)
+            cn += (c[-1:, :, :],)
+        status_h = self.concat(hn)
+        status_c = self.concat(cn)
+        return y, status_h, status_c
+
     def construct(self, x, hx):
         if self.batch_first:
             x = self.transpose(x, (1, 0, 2))
         h, c = hx
-        x, h, c, _, _ = self.lstm(x, h, c, self.weight)
+        if self.is_ascend:
+            # DynamicRNN requires float16 inputs and states on Ascend.
+            x = self.cast(x, mstype.float16)
+            h = self.cast(h, mstype.float16)
+            c = self.cast(c, mstype.float16)
+            if self.bidirectional:
+                x, h, c = self._stacked_bi_dynamic_rnn(x, h, c, self.w_list, self.b_list)
+            else:
+                x, h, c = self._stacked_dynamic_rnn(x, h, c, self.w_list, self.b_list)
+        else:
+            x, h, c, _, _ = self.lstm(x, h, c, self.weight)
         if self.batch_first:
             x = self.transpose(x, (1, 0, 2))
         return x, (h, c)
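
For reviewers: a minimal smoke test of the new Ascend branch. This sketch is not part of the patch; the concrete shapes and the `set_context` call are assumptions based on the public `nn.LSTM` interface, which the patch leaves unchanged (`construct(x, hx)` with `hx = (h0, c0)`).

```python
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor

# Selecting Ascend routes construct() through the new DynamicRNN-based branch;
# on GPU/CPU the existing P.LSTM path is taken unchanged.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

seq_len, batch_size, input_size, hidden_size, num_layers = 5, 3, 10, 16, 2
num_directions = 2  # bidirectional=True

net = nn.LSTM(input_size, hidden_size, num_layers=num_layers,
              has_bias=True, batch_first=False, bidirectional=True)

x = Tensor(np.ones((seq_len, batch_size, input_size), np.float32))
h0 = Tensor(np.zeros((num_layers * num_directions, batch_size, hidden_size), np.float32))
c0 = Tensor(np.zeros((num_layers * num_directions, batch_size, hidden_size), np.float32))

output, (hn, cn) = net(x, (h0, c0))
# output: (seq_len, batch_size, num_directions * hidden_size)
# hn, cn: (num_layers * num_directions, batch_size, hidden_size)
```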
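The per-layer parameter shapes follow from `P.DynamicRNN` taking the input-to-hidden and hidden-to-hidden matrices as a single concatenated tensor, which is why the patch allocates `[w_shape + hidden_size, gate_size]` per direction. A plain-arithmetic check of those shapes (the numbers are illustrative, not from the patch):

```python
# Shapes implied by the patch for bidirectional=True.
input_size, hidden_size, num_directions = 10, 16, 2
gate_size = 4 * hidden_size  # i, f, c-hat, o gates stacked along one axis

# Layer 0 consumes the raw input; deeper layers consume concat(fw, bw) outputs.
w_shape_l0 = input_size
w_shape_ln = num_directions * hidden_size

assert (w_shape_l0 + hidden_size, gate_size) == (26, 64)  # weight_fw0 / weight_bw0
assert (w_shape_ln + hidden_size, gate_size) == (48, 64)  # weight_fw1 / weight_bw1
```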