
lstm.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LSTM."""
import math

import numpy as np

from mindspore import Tensor, nn, context, Parameter, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore.ops import operations as P
import mindspore.ops.functional as F
import mindspore.common.dtype as mstype

STACK_LSTM_DEVICE = ["CPU"]
# Initialize short-term memory (h) and long-term memory (c) to 0
def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
    """init default input."""
    num_directions = 2 if bidirectional else 1
    h = Tensor(np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
    c = Tensor(np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
    return h, c


def stack_lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
    """init default input."""
    num_directions = 2 if bidirectional else 1

    # separate lists for hidden and cell states
    h_list, c_list = [], []
    for _ in range(num_layers):
        h_list.append(Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)))
        c_list.append(Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)))
    h, c = tuple(h_list), tuple(c_list)
    return h, c
def stack_lstm_default_state_ascend(batch_size, hidden_size, num_layers, bidirectional):
    """init default input."""
    # separate lists for hidden and cell states
    h_list, c_list = [], []
    for _ in range(num_layers):
        h_fw = Tensor(np.zeros((1, batch_size, hidden_size)).astype(np.float16))
        c_fw = Tensor(np.zeros((1, batch_size, hidden_size)).astype(np.float16))
        h_i = [h_fw]
        c_i = [c_fw]

        if bidirectional:
            h_bw = Tensor(np.zeros((1, batch_size, hidden_size)).astype(np.float16))
            c_bw = Tensor(np.zeros((1, batch_size, hidden_size)).astype(np.float16))
            h_i.append(h_bw)
            c_i.append(c_bw)

        h_list.append(h_i)
        c_list.append(c_i)
    h, c = tuple(h_list), tuple(c_list)
    return h, c
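
# Illustrative sketch: the three helpers above only differ in how the zero
# states are packaged. nn.LSTM takes one stacked tensor, StackLSTM takes one
# tensor per layer, and the Ascend variant keeps separate float16 tensors per
# direction. The helper below is not part of the training pipeline; it just
# prints the shapes for an assumed config (batch_size=64, hidden_size=200,
# num_layers=2, bidirectional=True).
def _example_state_shapes():
    """Illustrative only: show the state layouts produced by the initializers above."""
    h, c = lstm_default_state(64, 200, 2, True)
    # single tensor: (num_layers * num_directions, batch, hidden) = (4, 64, 200)
    print(h.shape, c.shape)

    h_s, c_s = stack_lstm_default_state(64, 200, 2, True)
    # one tensor per layer: 2 entries of shape (num_directions, batch, hidden) = (2, 64, 200)
    print(len(h_s), h_s[0].shape, len(c_s))

    h_a, c_a = stack_lstm_default_state_ascend(64, 200, 2, True)
    # per layer, a list with one (1, batch, hidden) float16 tensor per direction
    print(len(h_a), len(h_a[0]), h_a[0][0].shape)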
class StackLSTM(nn.Cell):
    """
    Stack multiple LSTM layers together.
    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 num_layers=1,
                 has_bias=True,
                 batch_first=False,
                 dropout=0.0,
                 bidirectional=False):
        super(StackLSTM, self).__init__()
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.transpose = P.Transpose()

        # direction number
        num_directions = 2 if bidirectional else 1

        # input_size list
        input_size_list = [input_size]
        for i in range(num_layers - 1):
            input_size_list.append(hidden_size * num_directions)

        # layers
        layers = []
        for i in range(num_layers):
            layers.append(nn.LSTMCell(input_size=input_size_list[i],
                                      hidden_size=hidden_size,
                                      has_bias=has_bias,
                                      batch_first=batch_first,
                                      bidirectional=bidirectional,
                                      dropout=dropout))

        # weights
        weights = []
        for i in range(num_layers):
            # weight size
            weight_size = (input_size_list[i] + hidden_size) * num_directions * hidden_size * 4
            if has_bias:
                bias_size = num_directions * hidden_size * 4
                weight_size = weight_size + bias_size

            # numpy weight
            stdv = 1 / math.sqrt(hidden_size)
            w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)

            # lstm weight
            weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name="weight" + str(i)))

        self.lstms = layers
        self.weight = ParameterTuple(tuple(weights))

    def construct(self, x, hx):
        """construct"""
        if self.batch_first:
            x = self.transpose(x, (1, 0, 2))
        # stack lstm
        h, c = hx
        hn = cn = None
        for i in range(self.num_layers):
            x, hn, cn, _, _ = self.lstms[i](x, h[i], c[i], self.weight[i])
        if self.batch_first:
            x = self.transpose(x, (1, 0, 2))
        return x, (hn, cn)
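
# Illustrative usage sketch for StackLSTM, assuming a MindSpore version whose
# nn.LSTMCell interface matches the calls above and an input laid out as
# (seq_len, batch, input_size). The sizes (seq_len=5, batch=4, input_size=16,
# hidden_size=8) are arbitrary.
def _example_stack_lstm():
    """Illustrative only: one forward pass through a 2-layer bidirectional StackLSTM."""
    net = StackLSTM(input_size=16, hidden_size=8, num_layers=2, bidirectional=True)
    h, c = stack_lstm_default_state(batch_size=4, hidden_size=8, num_layers=2, bidirectional=True)
    x = Tensor(np.random.randn(5, 4, 16).astype(np.float32))
    out, (hn, cn) = net(x, (h, c))
    # out: (5, 4, 16) -> (seq_len, batch, num_directions * hidden_size)
    print(out.shape, hn.shape, cn.shape)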
class LSTM_Ascend(nn.Cell):
    """LSTM on Ascend, built on the DynamicRNN operator."""

    def __init__(self, bidirectional=False):
        super(LSTM_Ascend, self).__init__()
        self.bidirectional = bidirectional
        self.dynamic_rnn = P.DynamicRNN(forget_bias=0.0)
        self.reverseV2 = P.ReverseV2(axis=[0])
        self.concat = P.Concat(2)

    def construct(self, x, h, c, w_f, b_f, w_b=None, b_b=None):
        """construct"""
        x = F.cast(x, mstype.float16)
        if self.bidirectional:
            # forward direction
            y1, h1, c1, _, _, _, _, _ = self.dynamic_rnn(x, w_f, b_f, None, h[0], c[0])
            # backward direction: reverse the sequence, run DynamicRNN, reverse the output back
            r_x = self.reverseV2(x)
            y2, h2, c2, _, _, _, _, _ = self.dynamic_rnn(r_x, w_b, b_b, None, h[1], c[1])
            y2 = self.reverseV2(y2)
            # concatenate both directions along the feature dimension
            output = self.concat((y1, y2))
            hn = self.concat((h1, h2))
            cn = self.concat((c1, c2))
            return output, (hn, cn)

        y1, h1, c1, _, _, _, _, _ = self.dynamic_rnn(x, w_f, b_f, None, h[0], c[0])
        return y1, (h1, c1)
class StackLSTMAscend(nn.Cell):
    """Stack multiple LSTM layers together (Ascend version)."""

    def __init__(self,
                 input_size,
                 hidden_size,
                 num_layers=1,
                 has_bias=True,
                 batch_first=False,
                 dropout=0.0,
                 bidirectional=False):
        super(StackLSTMAscend, self).__init__()
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.transpose = P.Transpose()

        # input_size list
        input_size_list = [input_size]
        for i in range(num_layers - 1):
            input_size_list.append(hidden_size * 2)

        # weights, bias and layers init
        weights_fw = []
        weights_bw = []
        bias_fw = []
        bias_bw = []

        stdv = 1 / math.sqrt(hidden_size)
        for i in range(num_layers):
            # forward weight init
            w_np_fw = np.random.uniform(-stdv,
                                        stdv,
                                        (input_size_list[i] + hidden_size, hidden_size * 4)).astype(np.float32)
            w_fw = Parameter(initializer(Tensor(w_np_fw), w_np_fw.shape), name="w_fw_layer" + str(i))
            weights_fw.append(w_fw)
            # forward bias init
            if has_bias:
                b_fw = np.random.uniform(-stdv, stdv, (hidden_size * 4)).astype(np.float32)
                b_fw = Parameter(initializer(Tensor(b_fw), b_fw.shape), name="b_fw_layer" + str(i))
            else:
                b_fw = np.zeros((hidden_size * 4)).astype(np.float32)
                b_fw = Parameter(initializer(Tensor(b_fw), b_fw.shape), name="b_fw_layer" + str(i))
            bias_fw.append(b_fw)

            if bidirectional:
                # backward weight init
                w_np_bw = np.random.uniform(-stdv,
                                            stdv,
                                            (input_size_list[i] + hidden_size, hidden_size * 4)).astype(np.float32)
                w_bw = Parameter(initializer(Tensor(w_np_bw), w_np_bw.shape), name="w_bw_layer" + str(i))
                weights_bw.append(w_bw)

                # backward bias init
                if has_bias:
                    b_bw = np.random.uniform(-stdv, stdv, (hidden_size * 4)).astype(np.float32)
                    b_bw = Parameter(initializer(Tensor(b_bw), b_bw.shape), name="b_bw_layer" + str(i))
                else:
                    b_bw = np.zeros((hidden_size * 4)).astype(np.float32)
                    b_bw = Parameter(initializer(Tensor(b_bw), b_bw.shape), name="b_bw_layer" + str(i))
                bias_bw.append(b_bw)

        # layer init
        self.lstm = LSTM_Ascend(bidirectional=bidirectional).to_float(mstype.float16)

        self.weight_fw = ParameterTuple(tuple(weights_fw))
        self.weight_bw = ParameterTuple(tuple(weights_bw))
        self.bias_fw = ParameterTuple(tuple(bias_fw))
        self.bias_bw = ParameterTuple(tuple(bias_bw))

    def construct(self, x, hx):
        """construct"""
        x = F.cast(x, mstype.float16)
        if self.batch_first:
            x = self.transpose(x, (1, 0, 2))
        # stack lstm
        h, c = hx
        hn = cn = None
        for i in range(self.num_layers):
            if self.bidirectional:
                x, (hn, cn) = self.lstm(x,
                                        h[i],
                                        c[i],
                                        self.weight_fw[i],
                                        self.bias_fw[i],
                                        self.weight_bw[i],
                                        self.bias_bw[i])
            else:
                x, (hn, cn) = self.lstm(x, h[i], c[i], self.weight_fw[i], self.bias_fw[i])
        if self.batch_first:
            x = self.transpose(x, (1, 0, 2))
        # cast the output and the final states back to float32
        x = F.cast(x, mstype.float32)
        hn = F.cast(hn, mstype.float32)
        cn = F.cast(cn, mstype.float32)
        return x, (hn, cn)
class SentimentNet(nn.Cell):
    """Sentiment network structure."""

    def __init__(self,
                 vocab_size,
                 embed_size,
                 num_hiddens,
                 num_layers,
                 bidirectional,
                 num_classes,
                 weight,
                 batch_size):
        super(SentimentNet, self).__init__()
        # Map words to vectors
        self.embedding = nn.Embedding(vocab_size,
                                      embed_size,
                                      embedding_table=weight)
        self.embedding.embedding_table.requires_grad = False
        self.trans = P.Transpose()
        self.perm = (1, 0, 2)

        if context.get_context("device_target") in STACK_LSTM_DEVICE:
            # stack lstm by user
            self.encoder = StackLSTM(input_size=embed_size,
                                     hidden_size=num_hiddens,
                                     num_layers=num_layers,
                                     has_bias=True,
                                     bidirectional=bidirectional,
                                     dropout=0.0)
            self.h, self.c = stack_lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional)
        elif context.get_context("device_target") == "GPU":
            # standard lstm
            self.encoder = nn.LSTM(input_size=embed_size,
                                   hidden_size=num_hiddens,
                                   num_layers=num_layers,
                                   has_bias=True,
                                   bidirectional=bidirectional,
                                   dropout=0.0)
            self.h, self.c = lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional)
        else:
            self.encoder = StackLSTMAscend(input_size=embed_size,
                                           hidden_size=num_hiddens,
                                           num_layers=num_layers,
                                           has_bias=True,
                                           bidirectional=bidirectional)
            self.h, self.c = stack_lstm_default_state_ascend(batch_size, num_hiddens, num_layers, bidirectional)

        self.concat = P.Concat(1)
        self.squeeze = P.Squeeze(axis=0)
        if bidirectional:
            self.decoder = nn.Dense(num_hiddens * 4, num_classes)
        else:
            self.decoder = nn.Dense(num_hiddens * 2, num_classes)

    def construct(self, inputs):
        # inputs: (64, 500) token ids -> embeddings: (64, 500, 300)
        embeddings = self.embedding(inputs)
        embeddings = self.trans(embeddings, self.perm)
        output, _ = self.encoder(embeddings, (self.h, self.c))
        # states[i] size (64, 200) -> encoding.size (64, 400)
        encoding = self.concat((self.squeeze(output[0:1:1]), self.squeeze(output[499:500:1])))
        outputs = self.decoder(encoding)
        return outputs
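
# Illustrative end-to-end sketch: build a SentimentNet with a random embedding
# table and run one forward pass. The sequence length must be 500 because
# construct() slices output[499:500:1]; vocab_size=1000, embed_size=300,
# num_hiddens=100, batch_size=64 and the CPU device target are assumptions.
if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    embedding_table = Tensor(np.random.uniform(-1, 1, (1000, 300)).astype(np.float32))
    net = SentimentNet(vocab_size=1000,
                       embed_size=300,
                       num_hiddens=100,
                       num_layers=2,
                       bidirectional=True,
                       num_classes=2,
                       weight=embedding_table,
                       batch_size=64)
    tokens = Tensor(np.random.randint(0, 1000, (64, 500)).astype(np.int32))
    logits = net(tokens)
    # logits: (64, 2) -> one score per class for each sentence in the batch
    print(logits.shape)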