You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cont_grad.py 30 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test control ops """
  16. import numpy as np
  17. from mindspore import dtype as ms
  18. from mindspore import Tensor
  19. from mindspore import context
  20. from mindspore import nn
  21. from mindspore.common.parameter import Parameter, ParameterTuple
  22. from mindspore.ops import composite as C
  23. from mindspore.ops import operations as P
  24. # from tests.vm_impl.math_ops_vm_impl import *
  25. # from tests.vm_impl.vm_interface import *
  26. # from tests.vm_impl import *
  27. # context.set_context(save_graphs=True)
  28. def test_while_forward():
  29. class MyWhileNet(nn.Cell):
  30. def __init__(self):
  31. super().__init__()
  32. self.max = P.ReduceMax()
  33. def construct(self, idx, end, x):
  34. while idx < end:
  35. part = x[idx, :, :]
  36. max_num = self.max(part)
  37. x[idx, :, 0:2] = max_num
  38. idx = idx + 1
  39. return x
  40. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  41. net = MyWhileNet()
  42. idx = Tensor(np.array(0), dtype=ms.int32)
  43. end = Tensor(np.array(2), dtype=ms.int32)
  44. x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
  45. net(idx, end, x)
  46. def test_while_grad():
  47. class MyWhileNet(nn.Cell):
  48. def __init__(self):
  49. super().__init__()
  50. self.max = P.ReduceMax()
  51. def construct(self, idx, end, x):
  52. while idx < end:
  53. part = x[idx, :, :]
  54. max_num = self.max(part)
  55. x[idx, :, 0:2] = max_num
  56. idx = idx + 1
  57. return x
  58. class GradNet(nn.Cell):
  59. def __init__(self, net):
  60. super(GradNet, self).__init__()
  61. self.net = net
  62. def construct(self, *inputs):
  63. return C.grad_all(self.net)(*inputs)
  64. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  65. while_net = MyWhileNet()
  66. net = GradNet(while_net)
  67. idx = Tensor(np.array(0), dtype=ms.int32)
  68. end = Tensor(np.array(2), dtype=ms.int32)
  69. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  70. net(idx, end, x)
  71. def test_while_with_param_forward():
  72. class MyWhileNet(nn.Cell):
  73. def __init__(self):
  74. super().__init__()
  75. self.max = P.ReduceMax()
  76. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  77. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  78. def construct(self, idx, end, x):
  79. out = self.zero
  80. while idx < end:
  81. part = x[idx, :, :]
  82. max_num = self.max(part)
  83. x[idx, :, 0:2] = max_num
  84. out = out + x + self.param
  85. idx = idx + 1
  86. return out
  87. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  88. net = MyWhileNet()
  89. idx = Tensor(np.array(0), dtype=ms.int32)
  90. end = Tensor(np.array(2), dtype=ms.int32)
  91. x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
  92. net(idx, end, x)
  93. def test_while_endless_case():
  94. """endless case when optmization"""
  95. class MyWhileNet(nn.Cell):
  96. def __init__(self):
  97. super().__init__()
  98. self.max = P.ReduceMax()
  99. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  100. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  101. def construct(self, idx, end, x):
  102. out = self.zero
  103. while idx < end:
  104. part = x[idx, :, :]
  105. out = out + part
  106. idx = idx + 1
  107. return out
  108. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  109. net = MyWhileNet()
  110. idx = Tensor(np.array(0), dtype=ms.int32)
  111. end = Tensor(np.array(2), dtype=ms.int32)
  112. x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
  113. net(idx, end, x)
  114. def test_while_with_param_grad():
  115. class MyWhileNet(nn.Cell):
  116. def __init__(self):
  117. super().__init__()
  118. self.max = P.ReduceMax()
  119. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  120. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  121. def construct(self, idx, end, x):
  122. out = self.zero
  123. while idx < end:
  124. part = x[idx, :, :]
  125. max_num = self.max(part)
  126. x[idx, :, 0:2] = max_num
  127. out = out + x + self.param
  128. idx = idx + 1
  129. return out
  130. class GradNet(nn.Cell):
  131. def __init__(self, net):
  132. super(GradNet, self).__init__()
  133. self.net = net
  134. self.weights = ParameterTuple(net.trainable_params())
  135. def construct(self, a, b, c):
  136. return C.grad_by_list(self.net, self.weights)(a, b, c)
  137. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  138. while_net = MyWhileNet()
  139. net = GradNet(while_net)
  140. idx = Tensor(np.array(0), dtype=ms.int32)
  141. end = Tensor(np.array(2), dtype=ms.int32)
  142. x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
  143. net(idx, end, x)
  144. def test_while_with_param_forward_with_const_branch():
  145. class MyWhileNet(nn.Cell):
  146. def __init__(self):
  147. super().__init__()
  148. self.max = P.ReduceMax()
  149. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  150. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  151. self.reduce = P.ReduceSum()
  152. def construct(self, idx, end, x):
  153. out = self.zero
  154. while idx < end:
  155. if 2 > 1:
  156. out = out + self.param
  157. else:
  158. out = out + idx + self.param
  159. idx = idx + 1
  160. return out
  161. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  162. while_net = MyWhileNet()
  163. net = while_net
  164. idx = Tensor(np.array(0), dtype=ms.int32)
  165. end = Tensor(np.array(4), dtype=ms.int32)
  166. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  167. net(idx, end, x)
  168. def test_while_opt_endless():
  169. """endless during optimization case"""
  170. class MyWhileNet(nn.Cell):
  171. def __init__(self):
  172. super().__init__()
  173. self.max = P.ReduceMax()
  174. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  175. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  176. self.reduce = P.ReduceSum()
  177. self.addn = P.AddN()
  178. def construct(self, idx, end, x):
  179. addn1 = self.addn((x, x, x))
  180. out = addn1
  181. while idx < end:
  182. out = self.addn((out, addn1))
  183. idx = idx + 1
  184. out = self.addn((out, x))
  185. return out
  186. class GradNet(nn.Cell):
  187. def __init__(self, net):
  188. super(GradNet, self).__init__()
  189. self.net = net
  190. def construct(self, *inputs):
  191. return C.grad_all(self.net)(*inputs)
  192. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  193. while_net = MyWhileNet()
  194. net = GradNet(while_net)
  195. idx = Tensor(np.array(0), dtype=ms.int32)
  196. end = Tensor(np.array(4), dtype=ms.int32)
  197. x = Tensor(np.ones([2, 2, 2]).astype(np.float32) * 3, dtype=ms.float32)
  198. net(idx, end, x)
  199. def test_no_while_call():
  200. class MyWhileNet(nn.Cell):
  201. def __init__(self):
  202. super().__init__()
  203. self.max = P.ReduceMax()
  204. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  205. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  206. self.reduce = P.ReduceSum()
  207. def construct(self, idx, end, x):
  208. out = self.zero
  209. if 2 > 1:
  210. out = out + self.param
  211. else:
  212. out = out + idx + self.param
  213. return out
  214. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  215. while_net = MyWhileNet()
  216. net = while_net
  217. idx = Tensor(np.array(0), dtype=ms.int32)
  218. end = Tensor(np.array(4), dtype=ms.int32)
  219. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  220. net(idx, end, x)
  221. def test_while_with_param_grad_with_const_branch():
  222. class MyWhileNet(nn.Cell):
  223. def __init__(self):
  224. super().__init__()
  225. self.max = P.ReduceMax()
  226. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  227. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  228. self.reduce = P.ReduceSum()
  229. def construct(self, idx, end, x):
  230. out = self.zero
  231. while idx < end:
  232. if 2 > 1:
  233. out = out + self.param
  234. else:
  235. out = out + idx + self.param
  236. idx = idx + 1
  237. return out
  238. class GradNet(nn.Cell):
  239. def __init__(self, net):
  240. super(GradNet, self).__init__()
  241. self.net = net
  242. self.weights = ParameterTuple(net.trainable_params())
  243. def construct(self, a, b, c):
  244. return C.grad_by_list(self.net, self.weights)(a, b, c)
  245. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  246. while_net = MyWhileNet()
  247. net = GradNet(while_net)
  248. idx = Tensor(np.array(0), dtype=ms.int32)
  249. end = Tensor(np.array(4), dtype=ms.int32)
  250. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  251. net(idx, end, x)
  252. def test_for_while_with_param_grad_with_const_branch():
  253. class MyWhileNet(nn.Cell):
  254. def __init__(self):
  255. super().__init__()
  256. self.max = P.ReduceMax()
  257. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  258. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  259. self.reduce = P.ReduceSum()
  260. self.start = Tensor(np.array(0), dtype=ms.int32)
  261. def construct(self, idx, end, x):
  262. out = self.zero
  263. for _ in range(0, 2):
  264. idx = self.start
  265. while idx < end:
  266. if 2 > 1:
  267. out = out + self.param
  268. else:
  269. out = out + idx + self.param
  270. idx = idx + 1
  271. return out
  272. class GradNet(nn.Cell):
  273. def __init__(self, net):
  274. super(GradNet, self).__init__()
  275. self.net = net
  276. self.weights = ParameterTuple(net.trainable_params())
  277. def construct(self, a, b, c):
  278. return C.grad_by_list(self.net, self.weights)(a, b, c)
  279. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  280. while_net = MyWhileNet()
  281. net = GradNet(while_net)
  282. idx = Tensor(np.array(0), dtype=ms.int32)
  283. end = Tensor(np.array(4), dtype=ms.int32)
  284. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  285. net(idx, end, x)
  286. def test_for_while_with_param_grad_basic():
  287. class MyWhileNet(nn.Cell):
  288. def __init__(self):
  289. super().__init__()
  290. self.max = P.ReduceMax()
  291. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  292. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  293. self.reduce = P.ReduceSum()
  294. self.start = Tensor(np.array(0), dtype=ms.int32)
  295. def construct(self, idx, end, x):
  296. out = self.zero
  297. for _ in range(0, 2):
  298. idx = self.start
  299. while idx < end:
  300. out = out + self.param
  301. idx = idx + 1
  302. return out
  303. class GradNet(nn.Cell):
  304. def __init__(self, net):
  305. super(GradNet, self).__init__()
  306. self.net = net
  307. self.weights = ParameterTuple(net.trainable_params())
  308. def construct(self, a, b, c):
  309. return C.grad_by_list(self.net, self.weights)(a, b, c)
  310. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  311. while_net = MyWhileNet()
  312. net = GradNet(while_net)
  313. idx = Tensor(np.array(0), dtype=ms.int32)
  314. end = Tensor(np.array(4), dtype=ms.int32)
  315. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  316. net(idx, end, x)
  317. def test_for_while_with_param_grad_normal():
  318. class MyWhileNet(nn.Cell):
  319. def __init__(self):
  320. super().__init__()
  321. self.max = P.ReduceMax()
  322. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  323. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  324. self.reduce = P.ReduceSum()
  325. self.start = Tensor(np.array(0), dtype=ms.int32)
  326. def construct(self, idx, end, x):
  327. out = x
  328. for _ in range(0, 2):
  329. idx = self.start
  330. while idx < end:
  331. out = out + self.param
  332. idx = idx + 1
  333. return out
  334. class GradNet(nn.Cell):
  335. def __init__(self, net):
  336. super(GradNet, self).__init__()
  337. self.net = net
  338. self.weights = ParameterTuple(net.trainable_params())
  339. def construct(self, a, b, c):
  340. return C.grad_by_list(self.net, self.weights)(a, b, c)
  341. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  342. while_net = MyWhileNet()
  343. net = GradNet(while_net)
  344. idx = Tensor(np.array(0), dtype=ms.int32)
  345. end = Tensor(np.array(4), dtype=ms.int32)
  346. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  347. net(idx, end, x)
  348. def test_while_with_param_basic_grad():
  349. class MyWhileNet(nn.Cell):
  350. def __init__(self):
  351. super().__init__()
  352. self.max = P.ReduceMax()
  353. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  354. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  355. self.t2 = Tensor(np.array(2), dtype=ms.float32)
  356. def construct(self, idx, end, x):
  357. out = self.zero
  358. while idx < end:
  359. out = out + self.param
  360. idx = idx + 1
  361. return out + self.param
  362. class GradNet(nn.Cell):
  363. def __init__(self, net):
  364. super(GradNet, self).__init__()
  365. self.net = net
  366. self.weights = ParameterTuple(net.trainable_params())
  367. def construct(self, a, b, c):
  368. return C.grad_by_list(self.net, self.weights)(a, b, c)
  369. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  370. while_net = MyWhileNet()
  371. net = GradNet(while_net)
  372. idx = Tensor(np.array(0), dtype=ms.int32)
  373. end = Tensor(np.array(3), dtype=ms.int32)
  374. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  375. net(idx, end, x)
  376. def test_while_with_param_basic_grad_mul():
  377. class MyWhileNet(nn.Cell):
  378. def __init__(self):
  379. super().__init__()
  380. self.max = P.ReduceMax()
  381. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  382. self.zero = Tensor(np.ones(([2, 2, 2])), ms.float32)
  383. self.t2 = Tensor(np.array(2), dtype=ms.float32)
  384. def construct(self, idx, end, x):
  385. out = self.zero
  386. while idx < end:
  387. out = out * self.param
  388. idx = idx + 1
  389. return out + self.param
  390. class GradNet(nn.Cell):
  391. def __init__(self, net):
  392. super(GradNet, self).__init__()
  393. self.net = net
  394. self.weights = ParameterTuple(net.trainable_params())
  395. def construct(self, a, b, c):
  396. return C.grad_by_list(self.net, self.weights)(a, b, c)
  397. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  398. while_net = MyWhileNet()
  399. net = GradNet(while_net)
  400. idx = Tensor(np.array(0), dtype=ms.int32)
  401. end = Tensor(np.array(3), dtype=ms.int32)
  402. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  403. net(idx, end, x)
  404. def test_while_with_param_basic_grad_two():
  405. class MyWhileNet(nn.Cell):
  406. def __init__(self):
  407. super().__init__()
  408. self.max = P.ReduceMax()
  409. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  410. self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
  411. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  412. self.t2 = Tensor(np.array(2), dtype=ms.float32)
  413. def construct(self, idx, end, x):
  414. out = self.zero
  415. while idx < end:
  416. out = out + self.param + self.weight
  417. idx = idx + 1
  418. return out + self.param
  419. class GradNet(nn.Cell):
  420. def __init__(self, net):
  421. super(GradNet, self).__init__()
  422. self.net = net
  423. self.weights = ParameterTuple(net.trainable_params())
  424. def construct(self, a, b, c):
  425. return C.grad_by_list(self.net, self.weights)(a, b, c)
  426. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  427. while_net = MyWhileNet()
  428. net = GradNet(while_net)
  429. idx = Tensor(np.array(0), dtype=ms.int32)
  430. end = Tensor(np.array(3), dtype=ms.int32)
  431. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  432. net(idx, end, x)
  433. def test_while_with_param_basic_grad_three():
  434. class MyWhileNet(nn.Cell):
  435. def __init__(self):
  436. super().__init__()
  437. self.max = P.ReduceMax()
  438. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  439. self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
  440. self.key = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="key")
  441. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  442. self.t2 = Tensor(np.array(2), dtype=ms.float32)
  443. def construct(self, idx, end, x):
  444. out = self.zero
  445. while idx < end:
  446. out = out + self.param + self.weight + self.key
  447. idx = idx + 1
  448. return out + self.param
  449. class GradNet(nn.Cell):
  450. def __init__(self, net):
  451. super(GradNet, self).__init__()
  452. self.net = net
  453. self.weights = ParameterTuple(net.trainable_params())
  454. def construct(self, a, b, c):
  455. return C.grad_by_list(self.net, self.weights)(a, b, c)
  456. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  457. while_net = MyWhileNet()
  458. net = GradNet(while_net)
  459. idx = Tensor(np.array(0), dtype=ms.int32)
  460. end = Tensor(np.array(3), dtype=ms.int32)
  461. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  462. net(idx, end, x)
  463. def test_while_if_with_param_grad():
  464. class MyWhileNet(nn.Cell):
  465. def __init__(self):
  466. super().__init__()
  467. self.max = P.ReduceMax()
  468. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  469. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  470. self.t2 = Tensor(np.array(2), dtype=ms.float32)
  471. def construct(self, idx, end, x):
  472. out = self.zero
  473. while idx < end:
  474. if self.max(out) < self.max(x):
  475. out = out + self.param * 2
  476. else:
  477. out = out + self.param
  478. idx = idx + 1
  479. return out + self.param
  480. class GradNet(nn.Cell):
  481. def __init__(self, net):
  482. super(GradNet, self).__init__()
  483. self.net = net
  484. self.weights = ParameterTuple(net.trainable_params())
  485. def construct(self, a, b, c):
  486. return C.grad_by_list(self.net, self.weights)(a, b, c)
  487. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  488. while_net = MyWhileNet()
  489. net = GradNet(while_net)
  490. idx = Tensor(np.array(0), dtype=ms.int32)
  491. end = Tensor(np.array(3), dtype=ms.int32)
  492. x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
  493. net(idx, end, x)
  494. def test_while_with_param_grad_not_enter_while():
  495. class MyWhileNet(nn.Cell):
  496. def __init__(self):
  497. super().__init__()
  498. self.max = P.ReduceMax()
  499. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  500. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  501. def construct(self, idx, end, x):
  502. out = self.zero
  503. while idx < end:
  504. out = out + self.param * 3
  505. idx = idx + 1
  506. return out + self.param
  507. class GradNet(nn.Cell):
  508. def __init__(self, net):
  509. super(GradNet, self).__init__()
  510. self.net = net
  511. self.weights = ParameterTuple(net.trainable_params())
  512. def construct(self, a, b, c):
  513. return C.grad_by_list(self.net, self.weights)(a, b, c)
  514. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  515. while_net = MyWhileNet()
  516. net = GradNet(while_net)
  517. idx = Tensor(np.array(3), dtype=ms.int32)
  518. end = Tensor(np.array(0), dtype=ms.int32)
  519. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  520. net(idx, end, x)
  521. def test_with_param_if_by_if_forward():
  522. class MyIfByIfNet(nn.Cell):
  523. def __init__(self):
  524. super().__init__()
  525. self.max = P.ReduceMax()
  526. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  527. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  528. def construct(self, a, b, x):
  529. out = self.zero
  530. if a < b:
  531. out = out + x + self.param
  532. else:
  533. out = out + x
  534. if a == b:
  535. out = out + x*3 + self.param
  536. else:
  537. out = out + x*2
  538. return out
  539. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  540. if_net = MyIfByIfNet()
  541. net = if_net
  542. idx = Tensor(np.array(0), dtype=ms.int32)
  543. end = Tensor(np.array(4), dtype=ms.int32)
  544. x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
  545. net(idx, end, x)
  546. def test_with_param_if_by_if_grad_inputs():
  547. class MyIfByIfNet(nn.Cell):
  548. def __init__(self):
  549. super().__init__()
  550. self.max = P.ReduceMax()
  551. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  552. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  553. def construct(self, a, b, x):
  554. out = self.zero
  555. if a < b:
  556. out = out + x + self.param * 4
  557. if a == b:
  558. out = out + x*3 + self.param * 3
  559. return out
  560. class GradNet(nn.Cell):
  561. def __init__(self, net):
  562. super(GradNet, self).__init__()
  563. self.net = net
  564. def construct(self, *inputs):
  565. return C.grad_all(self.net)(*inputs)
  566. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  567. if_net = MyIfByIfNet()
  568. net = GradNet(if_net)
  569. idx = Tensor(np.array(0), dtype=ms.int32)
  570. end = Tensor(np.array(0), dtype=ms.int32)
  571. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  572. net(idx, end, x)
  573. def test_with_param_if_by_if_grad_parameter():
  574. class MyIfByIfNet(nn.Cell):
  575. def __init__(self):
  576. super().__init__()
  577. self.max = P.ReduceMax()
  578. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  579. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  580. def construct(self, a, b, x):
  581. out = self.zero
  582. if a < b:
  583. out = out + x + self.param * 2
  584. if a == b:
  585. out = out + x*3 + self.param
  586. return out
  587. class GradNet(nn.Cell):
  588. def __init__(self, net):
  589. super(GradNet, self).__init__()
  590. self.net = net
  591. self.weights = ParameterTuple(net.trainable_params())
  592. def construct(self, *inputs):
  593. return C.grad_by_list(self.net, self.weights)(*inputs)
  594. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  595. if_net = MyIfByIfNet()
  596. net = GradNet(if_net)
  597. idx = Tensor(np.array(0), dtype=ms.int32)
  598. end = Tensor(np.array(2), dtype=ms.int32)
  599. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  600. net(idx, end, x)
  601. def test_with_param_if_by_if_grad_param_excute_null():
  602. class MyIfByIfNet(nn.Cell):
  603. def __init__(self):
  604. super().__init__()
  605. self.max = P.ReduceMax()
  606. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  607. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  608. def construct(self, a, b, x):
  609. out = self.zero
  610. if a < b:
  611. out = out + x + self.param * 2
  612. return out
  613. class GradNet(nn.Cell):
  614. def __init__(self, net):
  615. super(GradNet, self).__init__()
  616. self.net = net
  617. self.weights = ParameterTuple(net.trainable_params())
  618. def construct(self, *inputs):
  619. return C.grad_by_list(self.net, self.weights)(*inputs)
  620. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  621. if_net = MyIfByIfNet()
  622. net = GradNet(if_net)
  623. idx = Tensor(np.array(4), dtype=ms.int32)
  624. end = Tensor(np.array(0), dtype=ms.int32)
  625. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  626. net(idx, end, x)
  627. def test_if_by_if_return_inside_grad():
  628. class MyIfByIfNet(nn.Cell):
  629. def __init__(self):
  630. super().__init__()
  631. self.max = P.ReduceMax()
  632. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  633. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  634. def construct(self, a, b, x):
  635. out = self.zero
  636. if a < b:
  637. return out + x + self.param
  638. if a == b:
  639. return out + self.param * 2
  640. return out + self.param * 3
  641. class GradNet(nn.Cell):
  642. def __init__(self, net):
  643. super(GradNet, self).__init__()
  644. self.net = net
  645. self.weights = ParameterTuple(net.trainable_params())
  646. def construct(self, *inputs):
  647. return C.grad_by_list(self.net, self.weights)(*inputs)
  648. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  649. if_net = MyIfByIfNet()
  650. net = GradNet(if_net)
  651. idx = Tensor(np.array(1), dtype=ms.int32)
  652. end = Tensor(np.array(0), dtype=ms.int32)
  653. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  654. net(idx, end, x)
  655. def test_if_by_if_forward():
  656. class MyIfByIfNet(nn.Cell):
  657. def __init__(self):
  658. super().__init__()
  659. self.add = P.TensorAdd()
  660. self.sub = P.Sub()
  661. self.mul = P.Mul()
  662. self.div = P.RealDiv()
  663. def construct(self, a, b, x):
  664. if a < b:
  665. a = self.add(a, b)
  666. else:
  667. a = self.sub(a, b)
  668. if a == x:
  669. a = self.mul(a, b)
  670. else:
  671. a = self.div(a, b)
  672. if b == x:
  673. b = self.add(a, b)
  674. else:
  675. b = self.add(a, x)
  676. a = a * b
  677. out = a + b + x
  678. return out
  679. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
  680. if_net = MyIfByIfNet()
  681. net = if_net
  682. idx = Tensor(np.array(2), dtype=ms.float32)
  683. end = Tensor(np.array(3), dtype=ms.float32)
  684. x = Tensor(np.array(4), dtype=ms.float32)
  685. net(idx, end, x)