
!8001 Adapt nn.LSTM for Ascend.

From: @liu_xiao_93
Reviewed-by: @liangchenghui
Signed-off-by: @liangchenghui
tags/v1.1.0
mindspore-ci-bot committed 5 years ago
commit e16661d2d9
1 changed file with 99 additions and 3 deletions
  1. mindspore/nn/layer/lstm.py  +99 -3

mindspore/nn/layer/lstm.py  +99 -3

@@ -15,17 +15,27 @@
"""lstm"""
import math
import numpy as np
import mindspore.context as context
import mindspore.common.dtype as mstype
from mindspore.ops.primitive import constexpr
from mindspore._checkparam import Validator as validator
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
from mindspore.nn.cell import Cell
from mindspore import nn
from mindspore.ops import operations as P


__all__ = ['LSTM', 'LSTMCell']


@constexpr
def _create_sequence_length(shape):
    """Build a per-sample sequence-length tensor; every sample uses the full number of time steps."""
    num_step, batch_size, _ = shape
    sequence_length = Tensor(np.ones(batch_size, np.int32) * num_step, mstype.int32)
    return sequence_length
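# Illustration (values assumed, not from this commit): for an input of shape
# (num_step=5, batch_size=3, input_size=10), _create_sequence_length returns
# Tensor([5, 5, 5], mstype.int32), which ReverseSequence uses to flip each sample in full.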

class LSTM(Cell):
    r"""
    LSTM (Long Short-Term Memory) layer.
@@ -105,9 +115,20 @@ class LSTM(Cell):
        validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
        validator.check_positive_int(hidden_size, "hidden_size", self.cls_name)
        validator.check_positive_int(num_layers, "num_layers", self.cls_name)
        self.is_ascend = context.get_context("device_target") == "Ascend"

        self.batch_first = batch_first
        self.transpose = P.Transpose()
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.reverse_seq = P.ReverseSequence(batch_dim=1, seq_dim=0)
        self.concat = P.Concat(axis=0)
        self.concat_2dim = P.Concat(axis=2)
        self.cast = P.Cast()
        self.shape = P.Shape()
        if dropout != 0:
            # nn.Dropout takes the keep probability, so convert the drop rate.
            self.dropout_op = nn.Dropout(keep_prob=float(1 - dropout))
        self.lstm = P.LSTM(input_size=input_size,
                           hidden_size=hidden_size,
                           num_layers=num_layers,
@@ -117,23 +138,98 @@ class LSTM(Cell):

        weight_size = 0
        gate_size = 4 * hidden_size
        stdv = 1 / math.sqrt(hidden_size)
        num_directions = 2 if bidirectional else 1
        b0 = np.zeros(gate_size, dtype=np.float16)
        # Per-layer weights and biases for the Ascend path; DynamicRNN consumes float16 weights
        # of shape (layer_input_size + hidden_size, 4 * hidden_size) per direction.
        self.w_list = []
        self.b_list = []
        self.rnns_fw = P.DynamicRNN(forget_bias=0.0)
        self.rnns_bw = P.DynamicRNN(forget_bias=0.0)
        for layer in range(num_layers):
            input_layer_size = input_size if layer == 0 else hidden_size * num_directions
            increment_size = gate_size * input_layer_size
            increment_size += gate_size * hidden_size
            w_shape = input_size if layer == 0 else (num_directions * hidden_size)
            w_np = np.random.uniform(-stdv, stdv, (w_shape + hidden_size, gate_size)).astype(np.float16)
            self.w_list.append(Parameter(
                initializer(Tensor(w_np), [w_shape + hidden_size, gate_size]), name='weight_fw' + str(layer)))
            if has_bias:
                increment_size += 2 * gate_size
                b_np = np.random.uniform(-stdv, stdv, gate_size).astype(np.float16)
                self.b_list.append(Parameter(initializer(Tensor(b_np), [gate_size]), name='bias_fw' + str(layer)))
            else:
                self.b_list.append(Parameter(initializer(Tensor(b0), [gate_size]), name='bias_fw' + str(layer)))
            weight_size += increment_size * num_directions
            if bidirectional:
                w_bw_np = np.random.uniform(-stdv, stdv, (w_shape + hidden_size, gate_size)).astype(np.float16)
                self.w_list.append(Parameter(initializer(Tensor(w_bw_np), [w_shape + hidden_size, gate_size]),
                                             name='weight_bw' + str(layer)))
                b_bw_np = np.random.uniform(-stdv, stdv, (4 * hidden_size)).astype(np.float16) if has_bias else b0
                self.b_list.append(Parameter(initializer(Tensor(b_bw_np), [gate_size]), name='bias_bw' + str(layer)))
        self.w_list = ParameterTuple(self.w_list)
        self.b_list = ParameterTuple(self.b_list)
        w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
        self.weight = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight')
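        # Worked example (values assumed for illustration): with input_size=10, hidden_size=16,
        # num_layers=1, has_bias=True and bidirectional=False, gate_size = 64 and
        # weight_size = 64*10 + 64*16 + 2*64 = 1792, the length of the flattened weight consumed by P.LSTM.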

    def _stacked_bi_dynamic_rnn(self, x, init_h, init_c, weight, bias):
        """Stacked bidirectional DynamicRNN for the Ascend path."""
        x_shape = self.shape(x)
        sequence_length = _create_sequence_length(x_shape)
        pre_layer = x
        hn = ()
        cn = ()
        output = x
        for i in range(self.num_layers):
            offset = i * 2
            weight_fw, weight_bw = weight[offset], weight[offset + 1]
            bias_fw, bias_bw = bias[offset], bias[offset + 1]
            init_h_fw, init_h_bw = init_h[offset:offset + 1, :, :], init_h[offset + 1:offset + 2, :, :]
            init_c_fw, init_c_bw = init_c[offset:offset + 1, :, :], init_c[offset + 1:offset + 2, :, :]
            # Run the backward direction on the time-reversed sequence, then flip its outputs back.
            bw_x = self.reverse_seq(pre_layer, sequence_length)
            y, h, c, _, _, _, _, _ = self.rnns_fw(pre_layer, weight_fw, bias_fw, None, init_h_fw, init_c_fw)
            y_bw, h_bw, c_bw, _, _, _, _, _ = self.rnns_bw(bw_x, weight_bw, bias_bw, None, init_h_bw, init_c_bw)
            y_bw = self.reverse_seq(y_bw, sequence_length)
            output = self.concat_2dim((y, y_bw))
            pre_layer = self.dropout_op(output) if self.dropout else output
            hn += (h[-1:, :, :],)
            hn += (h_bw[-1:, :, :],)
            cn += (c[-1:, :, :],)
            cn += (c_bw[-1:, :, :],)
        status_h = self.concat(hn)
        status_c = self.concat(cn)
        return output, status_h, status_c

    def _stacked_dynamic_rnn(self, x, init_h, init_c, weight, bias):
        """Stacked multi-layer unidirectional DynamicRNN for the Ascend path."""
        pre_layer = x
        hn = ()
        cn = ()
        y = 0
        for i in range(self.num_layers):
            weight_fw, bias_fw = weight[i], bias[i]
            init_h_fw, init_c_fw = init_h[i:i + 1, :, :], init_c[i:i + 1, :, :]
            y, h, c, _, _, _, _, _ = self.rnns_fw(pre_layer, weight_fw, bias_fw, None, init_h_fw, init_c_fw)
            pre_layer = self.dropout_op(y) if self.dropout else y
            hn += (h[-1:, :, :],)
            cn += (c[-1:, :, :],)
        status_h = self.concat(hn)
        status_c = self.concat(cn)
        return y, status_h, status_c

    def construct(self, x, hx):
        if self.batch_first:
            x = self.transpose(x, (1, 0, 2))
        h, c = hx
        if self.is_ascend:
            # DynamicRNN runs in float16 on Ascend, so cast the input and initial states.
            x = self.cast(x, mstype.float16)
            h = self.cast(h, mstype.float16)
            c = self.cast(c, mstype.float16)
            if self.bidirectional:
                x, h, c = self._stacked_bi_dynamic_rnn(x, h, c, self.w_list, self.b_list)
            else:
                x, h, c = self._stacked_dynamic_rnn(x, h, c, self.w_list, self.b_list)
        else:
            x, h, c, _, _ = self.lstm(x, h, c, self.weight)
        if self.batch_first:
            x = self.transpose(x, (1, 0, 2))
        return x, (h, c)
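
A minimal usage sketch of the adapted layer (shapes and hyper-parameters are illustrative, not part of this commit; it assumes the nn.LSTM signature shown above):

import numpy as np
import mindspore.context as context
from mindspore import Tensor, nn

# Selecting the Ascend backend makes construct() take the DynamicRNN path; "GPU"/"CPU" keep P.LSTM.
context.set_context(device_target="Ascend")

net = nn.LSTM(input_size=10, hidden_size=16, num_layers=2, has_bias=True,
              batch_first=True, dropout=0.0, bidirectional=True)
x = Tensor(np.ones([3, 5, 10]).astype(np.float32))     # (batch, num_step, input_size)
h0 = Tensor(np.zeros([4, 3, 16]).astype(np.float32))   # (num_directions * num_layers, batch, hidden_size)
c0 = Tensor(np.zeros([4, 3, 16]).astype(np.float32))
output, (hn, cn) = net(x, (h0, c0))
print(output.shape)  # (3, 5, 32): batch-first output with num_directions * hidden_size features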

