Merge pull request !6374 from caojian05/ms_master_lstm_api_optimation (tag: v1.0.0)
@@ -14,12 +14,12 @@
 # ============================================================================
 """lstm"""
 import math
 import numpy as np
-import mindspore.nn as nn
-from mindspore import context
 from mindspore._checkparam import Validator as validator
 from mindspore.common.initializer import initializer
-from mindspore.common.parameter import Parameter, ParameterTuple
+from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
 from mindspore.nn.cell import Cell
 from mindspore.ops import operations as P
@@ -118,83 +118,41 @@ class LSTM(Cell):
                  dropout=0,
                  bidirectional=False):
         super(LSTM, self).__init__()
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-        self.num_layers = num_layers
-        self.has_bias = has_bias
-        self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
-        self.hidden_size = validator.check_integer("hidden_size", hidden_size, 0, Rel.GT, self.cls_name)
-        self.num_layers = validator.check_integer("num_layers", num_layers, 0, Rel.GT, self.cls_name)
-        self.dropout = float(dropout)
-        self.bidirectional = bidirectional
-        if self.batch_first:
-            self.transpose1 = P.Transpose()
-            self.transpose2 = P.Transpose()
-        num_directions = 2 if self.bidirectional else 1
-        self.cpu_target = False
-        enable_debug = context.get_context("enable_debug_runtime")
-        if context.get_context("device_target") == "CPU" and not enable_debug:
-            self.cpu_target = True
-        if not self.cpu_target:
-            self.lstm = P.LSTM(input_size=self.input_size,
-                               hidden_size=self.hidden_size,
-                               num_layers=self.num_layers,
-                               has_bias=self.has_bias,
-                               bidirectional=self.bidirectional,
-                               dropout=self.dropout)
-            weight_size = 0
-            gate_size = 4 * self.hidden_size
-            for layer in range(self.num_layers):
-                input_layer_size = self.input_size if layer == 0 else self.hidden_size * num_directions
-                increment_size = gate_size * input_layer_size
-                increment_size += gate_size * self.hidden_size
-                if self.has_bias:
-                    increment_size += 2 * gate_size
-                weight_size += increment_size * num_directions
-            stdv = 1 / math.sqrt(hidden_size)
-            w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
-            self.weight = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight')
-        else:
-            input_size_list = []
-            input_size_list.append(self.input_size)
-            for i in range(self.num_layers - 1):
-                input_size_list.append(self.hidden_size * num_directions)
-            weights = []
-            layers = []
-            bias_size = 0 if not self.has_bias else num_directions * self.hidden_size * 4
-            stdv = 1 / math.sqrt(hidden_size)
-            for i in range(num_layers):
-                weight_size = (input_size_list[i] + self.hidden_size) * num_directions * self.hidden_size * 4
-                if has_bias:
-                    weight_size = weight_size + bias_size
-                w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
-                weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name='weight' + str(i)))
-                layers.append(nn.LSTMCell(input_size=input_size_list[i],
-                                          hidden_size=self.hidden_size,
-                                          has_bias=self.has_bias,
-                                          bidirectional=self.bidirectional,
-                                          dropout=self.dropout))
-            self.lstms = layers
-            self.weight = ParameterTuple(tuple(weights))
-        self.fill = P.Fill()
-        self.shape = P.Shape()
+        validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
+        validator.check_integer("hidden_size", hidden_size, 0, Rel.GT, self.cls_name)
+        validator.check_integer("num_layers", num_layers, 0, Rel.GT, self.cls_name)
+        self.batch_first = batch_first
+        self.transpose = P.Transpose()
+        self.lstm = P.LSTM(input_size=input_size,
+                           hidden_size=hidden_size,
+                           num_layers=num_layers,
+                           has_bias=has_bias,
+                           bidirectional=bidirectional,
+                           dropout=float(dropout))
+        weight_size = 0
+        gate_size = 4 * hidden_size
+        num_directions = 2 if bidirectional else 1
+        for layer in range(num_layers):
+            input_layer_size = input_size if layer == 0 else hidden_size * num_directions
+            increment_size = gate_size * input_layer_size
+            increment_size += gate_size * hidden_size
+            if has_bias:
+                increment_size += 2 * gate_size
+            weight_size += increment_size * num_directions
+        stdv = 1 / math.sqrt(hidden_size)
+        w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
+        self.weight = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight')

     def construct(self, x, hx):
         if self.batch_first:
-            x = self.transpose1(x, (1, 0, 2))
-        if not self.cpu_target:
-            h, c = hx
-            output, h, c, _, _ = self.lstm(x, h, c, self.weight)
-            if self.batch_first:
-                output = self.transpose2(output, (1, 0, 2))
-            return (output, (h, c))
+            x = self.transpose(x, (1, 0, 2))
         h, c = hx
-        output, hn, cn, _, _ = self.lstms[0](x, h[0], c[0], self.weight[0])
-        for i in range(1, self.num_layers):
-            output, hn, cn, _, _ = self.lstms[i](output, h[i], c[i], self.weight[i])
+        x, h, c, _, _ = self.lstm(x, h, c, self.weight)
         if self.batch_first:
-            output = self.transpose2(output, (1, 0, 2))
-        return (output, (hn, cn))
+            x = self.transpose(x, (1, 0, 2))
+        return x, (h, c)


 class LSTMCell(Cell):
@@ -291,30 +249,19 @@ class LSTMCell(Cell):
                  dropout=0,
                  bidirectional=False):
         super(LSTMCell, self).__init__()
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-        self.has_bias = has_bias
         self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
-        self.dropout = float(dropout)
-        self.bidirectional = bidirectional
-        self.num_directions = 1
-        if self.bidirectional:
-            self.num_directions = 2
-        if self.batch_first:
-            self.transpose1 = P.Transpose()
-            self.transpose2 = P.Transpose()
-        self.lstm = P.LSTM(input_size=self.input_size,
-                           hidden_size=self.hidden_size,
+        self.transpose = P.Transpose()
+        self.lstm = P.LSTM(input_size=input_size,
+                           hidden_size=hidden_size,
                            num_layers=1,
-                           has_bias=self.has_bias,
-                           bidirectional=self.bidirectional,
-                           dropout=self.dropout)
+                           has_bias=has_bias,
+                           bidirectional=bidirectional,
+                           dropout=float(dropout))

     def construct(self, x, h, c, w):
         if self.batch_first:
-            x = self.transpose1(x, (1, 0, 2))
-        output, hn, cn, _, _ = self.lstm(x, h, c, w)
+            x = self.transpose(x, (1, 0, 2))
+        x, h, c, _, _ = self.lstm(x, h, c, w)
         if self.batch_first:
-            output = self.transpose2(output, (1, 0, 2))
-        return output, hn, cn, _, _
+            x = self.transpose(x, (1, 0, 2))
+        return x, h, c, _, _
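
Both the trimmed nn.LSTM above and the StackLSTM introduced below rely on the same flat weight layout: every layer's input-hidden matrix, hidden-hidden matrix and (optionally) biases are packed into one `(weight_size, 1, 1)` tensor. The following is a minimal sketch of that size bookkeeping, mirroring the loop in the new `__init__`; the dimensions are made up for illustration and are not taken from the patch:

```python
import math
import numpy as np


def lstm_flat_weight_size(input_size, hidden_size, num_layers, has_bias, bidirectional):
    """Replicates the weight_size loop from nn.LSTM.__init__ above."""
    num_directions = 2 if bidirectional else 1
    gate_size = 4 * hidden_size  # four gates per hidden unit
    weight_size = 0
    for layer in range(num_layers):
        input_layer_size = input_size if layer == 0 else hidden_size * num_directions
        increment_size = gate_size * input_layer_size   # input-hidden block
        increment_size += gate_size * hidden_size       # hidden-hidden block
        if has_bias:
            increment_size += 2 * gate_size             # b_ih and b_hh
        weight_size += increment_size * num_directions
    return weight_size


# Hypothetical dimensions, chosen only for illustration.
size = lstm_flat_weight_size(input_size=10, hidden_size=16, num_layers=2,
                             has_bias=True, bidirectional=True)
print(size)  # 9984
stdv = 1 / math.sqrt(16)
w_np = np.random.uniform(-stdv, stdv, (size, 1, 1)).astype(np.float32)  # the single flat parameter
```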
@@ -13,40 +13,108 @@
 # limitations under the License.
 # ============================================================================
 """LSTM."""
+import math
 import numpy as np
-from mindspore import Tensor, nn, context
+from mindspore import Tensor, nn, context, Parameter, ParameterTuple
+from mindspore.common.initializer import initializer
 from mindspore.ops import operations as P

+STACK_LSTM_DEVICE = ["CPU"]


 # Initialize short-term memory (h) and long-term memory (c) to 0
 def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
     """init default input."""
-    num_directions = 1
-    if bidirectional:
-        num_directions = 2
-    if context.get_context("device_target") == "CPU":
-        h_list = []
-        c_list = []
-        i = 0
-        while i < num_layers:
-            hi = Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32))
-            h_list.append(hi)
-            ci = Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32))
-            c_list.append(ci)
-            i = i + 1
-        h = tuple(h_list)
-        c = tuple(c_list)
-        return h, c
-    h = Tensor(
-        np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
-    c = Tensor(
-        np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
+    num_directions = 2 if bidirectional else 1
+    h = Tensor(np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
+    c = Tensor(np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
     return h, c

+def stack_lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
+    """init default input."""
+    num_directions = 2 if bidirectional else 1
+    h_list, c_list = [], []
+    for _ in range(num_layers):
+        h_list.append(Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)))
+        c_list.append(Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)))
+    h, c = tuple(h_list), tuple(c_list)
+    return h, c

+class StackLSTM(nn.Cell):
+    """
+    Stack multiple LSTM layers together.
+    """
+
+    def __init__(self,
+                 input_size,
+                 hidden_size,
+                 num_layers=1,
+                 has_bias=True,
+                 batch_first=False,
+                 dropout=0.0,
+                 bidirectional=False):
+        super(StackLSTM, self).__init__()
+        self.num_layers = num_layers
+        self.batch_first = batch_first
+        self.transpose = P.Transpose()
+
+        # direction number
+        num_directions = 2 if bidirectional else 1
+
+        # input_size list
+        input_size_list = [input_size]
+        for i in range(num_layers - 1):
+            input_size_list.append(hidden_size * num_directions)
+
+        # layers
+        layers = []
+        for i in range(num_layers):
+            layers.append(nn.LSTMCell(input_size=input_size_list[i],
+                                      hidden_size=hidden_size,
+                                      has_bias=has_bias,
+                                      batch_first=batch_first,
+                                      bidirectional=bidirectional,
+                                      dropout=dropout))
+
+        # weights
+        weights = []
+        for i in range(num_layers):
+            # weight size
+            weight_size = (input_size_list[i] + hidden_size) * num_directions * hidden_size * 4
+            if has_bias:
+                bias_size = num_directions * hidden_size * 4
+                weight_size = weight_size + bias_size
+
+            # numpy weight
+            stdv = 1 / math.sqrt(hidden_size)
+            w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
+
+            # lstm weight
+            weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name="weight" + str(i)))
+
+        self.lstms = layers
+        self.weight = ParameterTuple(tuple(weights))
+
+    def construct(self, x, hx):
+        """construct"""
+        if self.batch_first:
+            x = self.transpose(x, (1, 0, 2))
+        # stack lstm
+        h, c = hx
+        hn = cn = None
+        for i in range(self.num_layers):
+            x, hn, cn, _, _ = self.lstms[i](x, h[i], c[i], self.weight[i])
+        if self.batch_first:
+            x = self.transpose(x, (1, 0, 2))
+        return x, (hn, cn)
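
A minimal usage sketch of the StackLSTM path, following how SentimentNet wires it up below; the import path and all sizes are illustrative assumptions, not part of the patch:

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

# Assumed import location; adjust to wherever the StackLSTM above is defined.
from src.lstm import StackLSTM, stack_lstm_default_state

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class Encoder(nn.Cell):
    """Thin wrapper that owns the default states, mirroring SentimentNet's use of StackLSTM."""

    def __init__(self, batch_size, input_size, hidden_size, num_layers, bidirectional):
        super(Encoder, self).__init__()
        self.lstm = StackLSTM(input_size=input_size, hidden_size=hidden_size,
                              num_layers=num_layers, has_bias=True,
                              batch_first=False, bidirectional=bidirectional)
        # Per-layer (h, c) tuples; each entry has shape (num_directions, batch, hidden).
        self.h, self.c = stack_lstm_default_state(batch_size, hidden_size, num_layers, bidirectional)

    def construct(self, x):
        output, _ = self.lstm(x, (self.h, self.c))
        return output


# Made-up sizes, for illustration only.
batch_size, seq_len, input_size, hidden_size, num_layers = 4, 7, 10, 16, 2
net = Encoder(batch_size, input_size, hidden_size, num_layers, bidirectional=True)
x = Tensor(np.random.randn(seq_len, batch_size, input_size).astype(np.float32))
out = net(x)
print(out.shape)  # expected (7, 4, 32): seq_len, batch, num_directions * hidden_size
```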

 class SentimentNet(nn.Cell):
     """Sentiment network structure."""
@@ -67,14 +135,25 @@ class SentimentNet(nn.Cell):
         self.embedding.embedding_table.requires_grad = False
         self.trans = P.Transpose()
         self.perm = (1, 0, 2)
-        self.encoder = nn.LSTM(input_size=embed_size,
-                               hidden_size=num_hiddens,
-                               num_layers=num_layers,
-                               has_bias=True,
-                               bidirectional=bidirectional,
-                               dropout=0.0)
-        self.h, self.c = lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional)
+        if context.get_context("device_target") in STACK_LSTM_DEVICE:
+            # stack lstm by user
+            self.encoder = StackLSTM(input_size=embed_size,
+                                     hidden_size=num_hiddens,
+                                     num_layers=num_layers,
+                                     has_bias=True,
+                                     bidirectional=bidirectional,
+                                     dropout=0.0)
+            self.h, self.c = stack_lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional)
+        else:
+            # standard lstm
+            self.encoder = nn.LSTM(input_size=embed_size,
+                                   hidden_size=num_hiddens,
+                                   num_layers=num_layers,
+                                   has_bias=True,
+                                   bidirectional=bidirectional,
+                                   dropout=0.0)
+            self.h, self.c = lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional)

         self.concat = P.Concat(1)
         if bidirectional:
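
The device switch above is also why there are now two state helpers: nn.LSTM consumes a single stacked state tensor, while StackLSTM consumes a tuple with one state per layer. A small numpy-only sketch of the two layouts, with made-up sizes:

```python
import numpy as np

batch_size, hidden_size, num_layers, num_directions = 32, 100, 2, 2

# lstm_default_state / nn.LSTM: one tensor per state, all layers stacked along axis 0.
h_stacked = np.zeros((num_layers * num_directions, batch_size, hidden_size), np.float32)
print(h_stacked.shape)  # (4, 32, 100)

# stack_lstm_default_state / StackLSTM: a tuple with one (num_directions, batch, hidden) entry per layer.
h_tuple = tuple(np.zeros((num_directions, batch_size, hidden_size), np.float32)
                for _ in range(num_layers))
print(len(h_tuple), h_tuple[0].shape)  # 2 (2, 32, 100)
```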
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
+import math
 import pytest
 import numpy as np
@@ -20,12 +21,83 @@ import mindspore.context as context
 from mindspore.common.api import ms_function
 from mindspore.common.initializer import initializer
 from mindspore.ops import composite as C
+from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import ParameterTuple, Parameter

 context.set_context(mode=context.GRAPH_MODE, device_target='CPU')

+class StackLSTM(nn.Cell):
+    """
+    Stack multiple LSTM layers together.
+    """
+
+    def __init__(self,
+                 input_size,
+                 hidden_size,
+                 num_layers=1,
+                 has_bias=True,
+                 batch_first=False,
+                 dropout=0.0,
+                 bidirectional=False):
+        super(StackLSTM, self).__init__()
+        self.num_layers = num_layers
+        self.batch_first = batch_first
+        self.transpose = P.Transpose()
+
+        # direction number
+        num_directions = 2 if bidirectional else 1
+
+        # input_size list
+        input_size_list = [input_size]
+        for i in range(num_layers - 1):
+            input_size_list.append(hidden_size * num_directions)
+
+        # layers
+        layers = []
+        for i in range(num_layers):
+            layers.append(nn.LSTMCell(input_size=input_size_list[i],
+                                      hidden_size=hidden_size,
+                                      has_bias=has_bias,
+                                      batch_first=batch_first,
+                                      bidirectional=bidirectional,
+                                      dropout=dropout))
+
+        # weights
+        weights = []
+        for i in range(num_layers):
+            # weight size
+            weight_size = (input_size_list[i] + hidden_size) * num_directions * hidden_size * 4
+            if has_bias:
+                bias_size = num_directions * hidden_size * 4
+                weight_size = weight_size + bias_size
+
+            # numpy weight
+            stdv = 1 / math.sqrt(hidden_size)
+            w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
+
+            # lstm weight
+            weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name="weight" + str(i)))
+
+        self.lstms = layers
+        self.weight = ParameterTuple(tuple(weights))
+
+    def construct(self, x, hx):
+        """construct"""
+        if self.batch_first:
+            x = self.transpose(x, (1, 0, 2))
+        # stack lstm
+        h, c = hx
+        hn = cn = None
+        for i in range(self.num_layers):
+            x, hn, cn, _, _ = self.lstms[i](x, h[i], c[i], self.weight[i])
+        if self.batch_first:
+            x = self.transpose(x, (1, 0, 2))
+        return x, (hn, cn)

 class LstmNet(nn.Cell):
     def __init__(self, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
         super(LstmNet, self).__init__()
@@ -34,7 +106,7 @@ class LstmNet(nn.Cell):
         if bidirectional:
             num_directions = 2

-        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, has_bias, bidirectional, dropout)
+        self.lstm = StackLSTM(input_size, hidden_size, num_layers, has_bias, bidirectional, dropout)

         input_np = np.array([[[0.6755, -1.6607, 0.1367], [0.4276, -0.7850, -0.3758]],
                              [[-0.6424, -0.6095, 0.6639], [0.7918, 0.4147, -0.5089]],
                              [[-1.5612, 0.0120, -0.7289], [-0.6656, -0.6626, -0.5883]],
@@ -137,8 +209,8 @@ class MultiLayerBiLstmNet(nn.Cell):
         if bidirectional:
             num_directions = 2

-        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, has_bias=has_bias,
-                            bidirectional=bidirectional, dropout=dropout)
+        self.lstm = StackLSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, has_bias=has_bias,
+                              bidirectional=bidirectional, dropout=dropout)

         input_np = np.array([[[-0.1887, -0.4144, -0.0235, 0.7489, 0.7522, 0.5969, 0.3342, 1.2198, 0.6786, -0.9404],
                              [-0.8643, -1.6835, -2.4965, 2.8093, 0.1741, 0.2707, 0.7387, -0.0939, -1.7990, 0.4765]],
@@ -264,8 +336,8 @@ class Net(nn.Cell):
         bih = np.zeros((1, 8)).astype(np.float32)
         w_np = np.concatenate((wih, whh, bih), axis=1).reshape([-1, 1, 1])
         self.w = Parameter(initializer(Tensor(w_np), w_np.shape), name='weight0')
-        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
-                            has_bias=has_bias, bidirectional=bidirectional, dropout=dropout)
+        self.lstm = StackLSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
+                              has_bias=has_bias, bidirectional=bidirectional, dropout=dropout)
         self.lstm.weight = ParameterTuple(tuple([self.w]))

     @ms_function
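
For context on this last hunk: the test pins the LSTM to deterministic weights by building one flat `(N, 1, 1)` array per layer and overwriting `lstm.weight` with a ParameterTuple. Below is a sketch of that packing for a hypothetical unidirectional, single-layer cell; the sizes and random values are placeholders, not the test's actual data:

```python
import numpy as np
from mindspore import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple

# Hypothetical sizes; gate_size follows the four-gates-per-hidden-unit convention used above.
input_size, hidden_size, num_directions = 3, 2, 1
gate_size = 4 * hidden_size

wih = np.random.uniform(-0.1, 0.1, (1, gate_size * input_size)).astype(np.float32)   # input-hidden block
whh = np.random.uniform(-0.1, 0.1, (1, gate_size * hidden_size)).astype(np.float32)  # hidden-hidden block
bih = np.zeros((1, num_directions * gate_size)).astype(np.float32)                   # bias block (StackLSTM sizing)

# Concatenate and flatten into the per-layer (weight_size, 1, 1) layout, as Net does above.
w_np = np.concatenate((wih, whh, bih), axis=1).reshape([-1, 1, 1])
assert w_np.shape[0] == (input_size + hidden_size) * num_directions * hidden_size * 4 + num_directions * gate_size

w = Parameter(initializer(Tensor(w_np), w_np.shape), name='weight0')
weight = ParameterTuple((w,))
# net.lstm.weight = weight  # 'net' is hypothetical; this mirrors the override in the test's Net above
```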