# test_lstm_op.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import math
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore.common.api import ms_function
from mindspore.common.initializer import initializer
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import ParameterTuple, Parameter

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')


class StackLSTM(nn.Cell):
    """
    Stack multiple LSTM layers together.
    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 num_layers=1,
                 has_bias=True,
                 batch_first=False,
                 dropout=0.0,
                 bidirectional=False):
        super(StackLSTM, self).__init__()
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.transpose = P.Transpose()

        # direction number
        num_directions = 2 if bidirectional else 1

        # input_size list: layer 0 sees the raw input, deeper layers see the
        # (possibly bidirectional) output of the layer below
        input_size_list = [input_size]
        for i in range(num_layers - 1):
            input_size_list.append(hidden_size * num_directions)

        # layers
        layers = []
        for i in range(num_layers):
            layers.append(nn.LSTMCell(input_size=input_size_list[i],
                                      hidden_size=hidden_size,
                                      has_bias=has_bias,
                                      batch_first=batch_first,
                                      bidirectional=bidirectional,
                                      dropout=dropout))

        # weights
        weights = []
        for i in range(num_layers):
            # weight size
            weight_size = (input_size_list[i] + hidden_size) * num_directions * hidden_size * 4
            if has_bias:
                bias_size = num_directions * hidden_size * 4
                weight_size = weight_size + bias_size

            # numpy weight
            stdv = 1 / math.sqrt(hidden_size)
            w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)

            # lstm weight
            weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name="weight" + str(i)))

        self.lstms = layers
        self.weight = ParameterTuple(tuple(weights))

    def construct(self, x, hx):
        """construct"""
        if self.batch_first:
            x = self.transpose(x, (1, 0, 2))
        # stack lstm
        h, c = hx
        hn = cn = None
        for i in range(self.num_layers):
            x, hn, cn, _, _ = self.lstms[i](x, h[i], c[i], self.weight[i])
        if self.batch_first:
            x = self.transpose(x, (1, 0, 2))
        return x, (hn, cn)
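
# A minimal usage sketch for StackLSTM (illustrative only; the shapes below are
# assumptions, not values taken from the tests). With batch_first=False the
# input is (seq_len, batch, input_size), and the (h, c) tuples hold one state
# per layer, each of shape (num_directions, batch, hidden_size):
#
#     net = StackLSTM(input_size=3, hidden_size=2, num_layers=1)
#     x = Tensor(np.zeros((5, 2, 3), np.float32))
#     h = (Tensor(np.zeros((1, 2, 2), np.float32)),)
#     c = (Tensor(np.zeros((1, 2, 2), np.float32)),)
#     y, (hn, cn) = net(x, (h, c))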

class LstmNet(nn.Cell):
    def __init__(self, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
        super(LstmNet, self).__init__()

        num_directions = 1
        if bidirectional:
            num_directions = 2

        # pass the flags by keyword: StackLSTM declares batch_first before dropout
        # and bidirectional, so passing these positionally would silently set
        # batch_first instead of bidirectional
        self.lstm = StackLSTM(input_size, hidden_size, num_layers,
                              has_bias=has_bias, bidirectional=bidirectional, dropout=dropout)

        input_np = np.array([[[0.6755, -1.6607, 0.1367], [0.4276, -0.7850, -0.3758]],
                             [[-0.6424, -0.6095, 0.6639], [0.7918, 0.4147, -0.5089]],
                             [[-1.5612, 0.0120, -0.7289], [-0.6656, -0.6626, -0.5883]],
                             [[-0.9667, -0.6296, -0.7310], [0.1026, -0.6821, -0.4387]],
                             [[-0.4710, 0.6558, -0.3144], [-0.8449, -0.2184, -0.1806]]
                             ]).astype(np.float32)
        self.x = Tensor(input_np)

        self.h = Tensor(np.array([0., 0., 0., 0.]).reshape((num_directions, batch_size, hidden_size)).astype(
            np.float32))
        self.c = Tensor(np.array([0., 0., 0., 0.]).reshape((num_directions, batch_size, hidden_size)).astype(
            np.float32))
        self.h = tuple((self.h,))
        self.c = tuple((self.c,))

        wih = np.array([[3.4021e-01, -4.6622e-01, 4.5117e-01],
                        [-6.4257e-02, -2.4807e-01, 1.3550e-02],  # i
                        [-3.2140e-01, 5.5578e-01, 6.3589e-01],
                        [1.6547e-01, -7.9030e-02, -2.0045e-01],
                        [-6.9863e-01, 5.9773e-01, -3.9062e-01],
                        [-3.0253e-01, -1.9464e-01, 7.0591e-01],
                        [-4.0835e-01, 3.6751e-01, 4.7989e-01],
                        [-5.6894e-01, -5.0359e-01, 4.7491e-01]]).astype(np.float32).reshape([1, -1])
        whh = np.array([[-0.4820, -0.2350],
                        [-0.1195, 0.0519],
                        [0.2162, -0.1178],
                        [0.6237, 0.0711],
                        [0.4511, -0.3961],
                        [-0.5962, 0.0906],
                        [0.1867, -0.1225],
                        [0.1831, 0.0850]]).astype(np.float32).reshape([1, -1])
        bih = np.zeros((1, 8)).astype(np.float32)
        w_np = np.concatenate((wih, whh, bih), axis=1).reshape([-1, 1, 1])
        self.w = Parameter(initializer(Tensor(w_np), w_np.shape), name='w')
        self.lstm.weight = ParameterTuple((self.w,))

    @ms_function
    def construct(self):
        return self.lstm(self.x, (self.h, self.c))
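
# Note on the weight layout above: StackLSTM stores each layer's parameters as
# one flattened (N, 1, 1) tensor that concatenates the input-hidden block (wih),
# the hidden-hidden block (whh), and the bias. The 8 rows of wih and whh are the
# four LSTM gates stacked for hidden_size=2 (4 * hidden_size = 8), so the total
# of 24 + 16 + 8 = 48 values matches StackLSTM's weight_size formula
# (input_size + hidden_size) * num_directions * hidden_size * 4 + bias_size.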

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_lstm():
    seq_len = 5
    batch_size = 2

    input_size = 3
    hidden_size = 2
    num_layers = 1
    has_bias = True
    bidirectional = False
    dropout = 0.0

    num_directions = 1
    if bidirectional:
        num_directions = 2

    net = LstmNet(batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout)
    y, (h, c) = net()
    print(y)
    print(c)
    print(h)

    expect_y = [[[-0.17992045, 0.07819052],
                 [-0.10745212, -0.06291768]],
                [[-0.28830513, 0.30579978],
                 [-0.07570618, -0.08868407]],
                [[-0.00814095, 0.16889746],
                 [0.02814853, -0.11208838]],
                [[0.08157863, 0.06088024],
                 [-0.04227093, -0.11514835]],
                [[0.18908429, -0.02963362],
                 [0.09106826, -0.00602506]]]
    expect_h = [[[0.18908429, -0.02963362],
                 [0.09106826, -0.00602506]]]
    expect_c = [[[0.3434288, -0.06561527],
                 [0.16838229, -0.00972614]]]

    diff_y = y.asnumpy() - expect_y
    error_y = np.ones([seq_len, batch_size, hidden_size]) * 1.0e-4
    assert np.all(diff_y < error_y)
    assert np.all(-diff_y < error_y)

    diff_h = h.asnumpy() - expect_h
    error_h = np.ones([num_layers * num_directions, batch_size, hidden_size]) * 1.0e-4
    assert np.all(diff_h < error_h)
    assert np.all(-diff_h < error_h)

    diff_c = c.asnumpy() - expect_c
    error_c = np.ones([num_layers * num_directions, batch_size, hidden_size]) * 1.0e-4
    assert np.all(diff_c < error_c)
    assert np.all(-diff_c < error_c)
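
# The paired assertions above bound the elementwise error: diff < err together
# with -diff < err means |actual - expected| < 1.0e-4. A near-equivalent, more
# compact check (shown for clarity only, not used by the test) would be:
#
#     assert np.allclose(y.asnumpy(), expect_y, rtol=0, atol=1.0e-4)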

class MultiLayerBiLstmNet(nn.Cell):
    def __init__(self, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
        super(MultiLayerBiLstmNet, self).__init__()

        num_directions = 1
        if bidirectional:
            num_directions = 2

        self.lstm = StackLSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
                              has_bias=has_bias, bidirectional=bidirectional, dropout=dropout)

        input_np = np.array([[[-0.1887, -0.4144, -0.0235, 0.7489, 0.7522, 0.5969, 0.3342, 1.2198, 0.6786, -0.9404],
                              [-0.8643, -1.6835, -2.4965, 2.8093, 0.1741, 0.2707, 0.7387, -0.0939, -1.7990, 0.4765]],
                             [[-0.5963, -1.2598, -0.7226, 1.1365, -1.7320, -0.7302, 0.1221, -0.2111, -1.6173, -0.0706],
                              [0.8964, 0.1737, -1.0077, -0.1389, 0.4889, 0.4391, 0.7911, 0.3614, -1.9533, -0.9936]],
                             [[0.3260, -1.3312, 0.0601, 1.0726, -1.6010, -1.8733, -1.5775, 1.1579, -0.8801, -0.5742],
                              [-2.2998, -0.6344, -0.5409, -0.9221, -0.6500, 0.1206, 1.5215, 0.7517, 1.3691, 2.0021]],
                             [[-0.1245, -0.3690, 2.1193, 1.3852, -0.1841, -0.8899, -0.3646, -0.8575, -0.3131, 0.2026],
                              [1.0218, -1.4331, 0.1744, 0.5442, -0.7808, 0.2527, 0.1566, 1.1484, -0.7766, -0.6747]],
                             [[-0.6752, 0.9906, -0.4973, 0.3471, -0.1202, -0.4213, 2.0213, 0.0441, 0.9016, 1.0365],
                              [1.2223, -1.3248, 0.1207, -0.8256, 0.1816, 0.7057, -0.3105, 0.5713, 0.2804,
                               -1.0685]]]).astype(np.float32)
        self.x = Tensor(input_np)

        self.h0 = Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32))
        self.c0 = Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32))
        self.h1 = Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32))
        self.c1 = Tensor(np.ones((num_directions, batch_size, hidden_size)).astype(np.float32))
        self.h = tuple((self.h0, self.h1))
        self.c = tuple((self.c0, self.c1))

        input_size_list = [input_size, hidden_size * num_directions]
        weights = []
        bias_size = 0 if not has_bias else num_directions * hidden_size * 4
        for i in range(num_layers):
            weight_size = (input_size_list[i] + hidden_size) * num_directions * hidden_size * 4
            w_np = np.ones([weight_size, 1, 1]).astype(np.float32) * 0.02
            if has_bias:
                bias_np = np.zeros([bias_size, 1, 1]).astype(np.float32)
                w_np = np.concatenate([w_np, bias_np], axis=0)
            weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name='weight' + str(i)))
        # wrap in a ParameterTuple, matching how StackLSTM stores its weights
        self.lstm.weight = ParameterTuple(tuple(weights))

    @ms_function
    def construct(self):
        return self.lstm(self.x, (self.h, self.c))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_multi_layer_bilstm():
    batch_size = 2
    input_size = 10
    hidden_size = 2
    num_layers = 2
    has_bias = True
    bidirectional = True
    dropout = 0.0

    net = MultiLayerBiLstmNet(batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional,
                              dropout)
    y, (h, c) = net()
    print(y)
    print(h)
    print(c)


class Grad(nn.Cell):
    def __init__(self, network):
        super(Grad, self).__init__()
        self.network = network
        self.weights = ParameterTuple(network.trainable_params())
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)

    @ms_function
    def construct(self, output_grad):
        weights = self.weights
        grads = self.grad(self.network, weights)(output_grad)
        return grads
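
# Grad returns the gradients of the wrapped network's output with respect to
# its trainable parameters (get_by_list=True). Because sens_param=True, the
# caller supplies output_grad as the sensitivity, i.e. the upstream gradient of
# the output, so each returned entry is the corresponding vector-Jacobian
# product rather than the gradient of a scalar loss.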

class Net(nn.Cell):
    def __init__(self, seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
        super(Net, self).__init__()

        num_directions = 1
        if bidirectional:
            num_directions = 2

        input_np = np.array([[[0.6755, -1.6607, 0.1367], [0.4276, -0.7850, -0.3758]],
                             [[-0.6424, -0.6095, 0.6639], [0.7918, 0.4147, -0.5089]],
                             [[-1.5612, 0.0120, -0.7289], [-0.6656, -0.6626, -0.5883]],
                             [[-0.9667, -0.6296, -0.7310], [0.1026, -0.6821, -0.4387]],
                             [[-0.4710, 0.6558, -0.3144], [-0.8449, -0.2184, -0.1806]]
                             ]).astype(np.float32)
        self.x = Parameter(initializer(Tensor(input_np), [seq_len, batch_size, input_size]), name='x')

        self.hlist = []
        self.clist = []
        self.hlist.append(Parameter(initializer(
            Tensor(
                np.array([0.1, 0.1, 0.1, 0.1]).reshape((num_directions, batch_size, hidden_size)).astype(
                    np.float32)),
            [num_directions, batch_size, hidden_size]), name='h'))
        self.clist.append(Parameter(initializer(
            Tensor(
                np.array([0.2, 0.2, 0.2, 0.2]).reshape((num_directions, batch_size, hidden_size)).astype(
                    np.float32)),
            [num_directions, batch_size, hidden_size]), name='c'))
        self.h = ParameterTuple(tuple(self.hlist))
        self.c = ParameterTuple(tuple(self.clist))

        wih = np.array([[3.4021e-01, -4.6622e-01, 4.5117e-01],
                        [-6.4257e-02, -2.4807e-01, 1.3550e-02],  # i
                        [-3.2140e-01, 5.5578e-01, 6.3589e-01],
                        [1.6547e-01, -7.9030e-02, -2.0045e-01],
                        [-6.9863e-01, 5.9773e-01, -3.9062e-01],
                        [-3.0253e-01, -1.9464e-01, 7.0591e-01],
                        [-4.0835e-01, 3.6751e-01, 4.7989e-01],
                        [-5.6894e-01, -5.0359e-01, 4.7491e-01]]).astype(np.float32).reshape([1, -1])
        whh = np.array([[-0.4820, -0.2350],
                        [-0.1195, 0.0519],
                        [0.2162, -0.1178],
                        [0.6237, 0.0711],
                        [0.4511, -0.3961],
                        [-0.5962, 0.0906],
                        [0.1867, -0.1225],
                        [0.1831, 0.0850]]).astype(np.float32).reshape([1, -1])
        bih = np.zeros((1, 8)).astype(np.float32)
        w_np = np.concatenate((wih, whh, bih), axis=1).reshape([-1, 1, 1])
        self.w = Parameter(initializer(Tensor(w_np), w_np.shape), name='weight0')

        self.lstm = StackLSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
                              has_bias=has_bias, bidirectional=bidirectional, dropout=dropout)
        self.lstm.weight = ParameterTuple((self.w,))

    @ms_function
    def construct(self):
        return self.lstm(self.x, (self.h, self.c))[0]

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_grad():
    seq_len = 5
    batch_size = 2

    input_size = 3
    hidden_size = 2
    num_layers = 1
    has_bias = False
    bidirectional = False
    dropout = 0.0

    net = Grad(Net(seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout))

    dy = np.array([[[-3.5471e-01, 7.0540e-01],
                    [2.7161e-01, 1.0865e+00]],
                   [[-4.2431e-01, 1.4955e+00],
                    [-4.0418e-01, -2.3282e-01]],
                   [[-1.3654e+00, 1.9251e+00],
                    [-4.6481e-01, 1.3138e+00]],
                   [[1.2914e+00, -2.3753e-01],
                    [5.3589e-01, -1.0981e-01]],
                   [[-1.6032e+00, -1.8818e-01],
                    [1.0065e-01, 9.2045e-01]]]).astype(np.float32)
    dx, dhx, dcx, dw = net(Tensor(dy))
    print(dx)
    print(dhx)
    print(dcx)
    print(dw)
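
# The four gradients presumably follow the registration order of Net's
# trainable Parameters: the input x, the initial h and c states, and the
# flattened LSTM weight, which is why the result unpacks into dx, dhx, dcx, dw.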

# guard the direct invocations so pytest does not run the tests again at import time
if __name__ == '__main__':
    test_multi_layer_bilstm()
    test_lstm()
    test_grad()