
lstm.py 6.5 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LSTM."""
import math

import numpy as np

from mindspore import Tensor, nn, context, Parameter, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P

# Devices on which the hand-stacked StackLSTM below is used instead of nn.LSTM
STACK_LSTM_DEVICE = ["CPU"]

# Initialize short-term memory (h) and long-term memory (c) to 0
def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
    """init default input."""
    num_directions = 2 if bidirectional else 1
    h = Tensor(np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
    c = Tensor(np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
    return h, c

def stack_lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
    """init default input."""
    num_directions = 2 if bidirectional else 1
    # Keep two separate lists; `h_list = c_list = []` would alias them to a single list.
    h_list, c_list = [], []
    for _ in range(num_layers):
        h_list.append(Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)))
        c_list.append(Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)))
    h, c = tuple(h_list), tuple(c_list)
    return h, c
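
# A minimal shape-check sketch (an illustrative addition, not part of the original
# model zoo file). It shows what the two helpers above return: lstm_default_state
# packs all layers into one tensor for nn.LSTM, while stack_lstm_default_state keeps
# one (num_directions, batch, hidden) tensor per layer for the StackLSTM cell defined
# below. The sizes used here are arbitrary assumptions.
def _demo_default_states():
    h, c = lstm_default_state(batch_size=4, hidden_size=8, num_layers=2, bidirectional=True)
    print(h.shape, c.shape)                # (4, 4, 8) (4, 4, 8): (num_layers * num_directions, batch, hidden)

    h_tuple, c_tuple = stack_lstm_default_state(batch_size=4, hidden_size=8, num_layers=2, bidirectional=True)
    print(len(h_tuple), h_tuple[0].shape)  # 2 (2, 4, 8): one state tensor per layer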

class StackLSTM(nn.Cell):
    """
    Stack multiple LSTM layers together.
    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 num_layers=1,
                 has_bias=True,
                 batch_first=False,
                 dropout=0.0,
                 bidirectional=False):
        super(StackLSTM, self).__init__()
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.transpose = P.Transpose()

        # direction number
        num_directions = 2 if bidirectional else 1

        # input_size list
        input_size_list = [input_size]
        for i in range(num_layers - 1):
            input_size_list.append(hidden_size * num_directions)

        # layers
        layers = []
        for i in range(num_layers):
            layers.append(nn.LSTMCell(input_size=input_size_list[i],
                                      hidden_size=hidden_size,
                                      has_bias=has_bias,
                                      batch_first=batch_first,
                                      bidirectional=bidirectional,
                                      dropout=dropout))

        # weights
        weights = []
        for i in range(num_layers):
            # weight size
            weight_size = (input_size_list[i] + hidden_size) * num_directions * hidden_size * 4
            if has_bias:
                bias_size = num_directions * hidden_size * 4
                weight_size = weight_size + bias_size

            # numpy weight
            stdv = 1 / math.sqrt(hidden_size)
            w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)

            # lstm weight
            weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name="weight" + str(i)))

        # cells and their flattened weights
        self.lstms = layers
        self.weight = ParameterTuple(tuple(weights))

    def construct(self, x, hx):
        """construct"""
        if self.batch_first:
            x = self.transpose(x, (1, 0, 2))
        # stack lstm
        h, c = hx
        hn = cn = None
        for i in range(self.num_layers):
            x, hn, cn, _, _ = self.lstms[i](x, h[i], c[i], self.weight[i])
        if self.batch_first:
            x = self.transpose(x, (1, 0, 2))
        return x, (hn, cn)
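
# Usage sketch for StackLSTM (an illustrative addition, not in the original file):
# feed a random sequence through a two-layer, unidirectional stack on CPU.
# All sizes below are assumptions chosen for the example.
def _demo_stack_lstm():
    batch_size, seq_len, input_size, hidden_size, num_layers = 2, 5, 16, 32, 2
    net = StackLSTM(input_size=input_size, hidden_size=hidden_size,
                    num_layers=num_layers, bidirectional=False)
    x = Tensor(np.random.randn(seq_len, batch_size, input_size).astype(np.float32))
    hx = stack_lstm_default_state(batch_size, hidden_size, num_layers, bidirectional=False)
    y, _ = net(x, hx)
    print(y.shape)                         # (5, 2, 32): (seq_len, batch, num_directions * hidden_size)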

class SentimentNet(nn.Cell):
    """Sentiment network structure."""

    def __init__(self,
                 vocab_size,
                 embed_size,
                 num_hiddens,
                 num_layers,
                 bidirectional,
                 num_classes,
                 weight,
                 batch_size):
        super(SentimentNet, self).__init__()
        # Map words to vectors
        self.embedding = nn.Embedding(vocab_size,
                                      embed_size,
                                      embedding_table=weight)
        self.embedding.embedding_table.requires_grad = False
        self.trans = P.Transpose()
        self.perm = (1, 0, 2)
        if context.get_context("device_target") in STACK_LSTM_DEVICE:
            # stack lstm by user
            self.encoder = StackLSTM(input_size=embed_size,
                                     hidden_size=num_hiddens,
                                     num_layers=num_layers,
                                     has_bias=True,
                                     bidirectional=bidirectional,
                                     dropout=0.0)
            self.h, self.c = stack_lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional)
        else:
            # standard lstm
            self.encoder = nn.LSTM(input_size=embed_size,
                                   hidden_size=num_hiddens,
                                   num_layers=num_layers,
                                   has_bias=True,
                                   bidirectional=bidirectional,
                                   dropout=0.0)
            self.h, self.c = lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional)

        self.concat = P.Concat(1)
        if bidirectional:
            self.decoder = nn.Dense(num_hiddens * 4, num_classes)
        else:
            self.decoder = nn.Dense(num_hiddens * 2, num_classes)

    def construct(self, inputs):
        # inputs: (64, 500) token ids -> embeddings: (64, 500, 300)
        embeddings = self.embedding(inputs)
        embeddings = self.trans(embeddings, self.perm)
        output, _ = self.encoder(embeddings, (self.h, self.c))
        # states[i] size (64, 200) -> encoding size (64, 400)
        encoding = self.concat((output[0], output[499]))
        outputs = self.decoder(encoding)
        return outputs
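
# Usage sketch for SentimentNet (an illustrative addition, not part of the original
# file): build the classifier with a random embedding table and run one batch.
# Batch size 64, sequence length 500, and embedding size 300 mirror the shape
# comments in construct(); vocab_size, num_hiddens, and num_classes are assumptions.
def _demo_sentiment_net():
    vocab_size, embed_size, num_hiddens, num_layers = 20000, 300, 100, 2
    num_classes, batch_size, seq_len = 2, 64, 500
    embedding_table = Tensor(np.random.uniform(-1, 1, (vocab_size, embed_size)).astype(np.float32))
    net = SentimentNet(vocab_size, embed_size, num_hiddens, num_layers,
                       bidirectional=True, num_classes=num_classes,
                       weight=embedding_table, batch_size=batch_size)
    tokens = Tensor(np.random.randint(0, vocab_size, (batch_size, seq_len)).astype(np.int32))
    logits = net(tokens)
    print(logits.shape)                    # (64, 2): one score per class for each sample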