You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_cont_grad.py 42 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test control ops """
  16. import numpy as np
  17. import pytest
  18. from mindspore import dtype as ms
  19. from mindspore import Tensor
  20. from mindspore import context
  21. from mindspore import nn
  22. from mindspore.common.parameter import Parameter, ParameterTuple
  23. from mindspore.ops import composite as C
  24. from mindspore.ops import operations as P
  25. # from tests.vm_impl.math_ops_vm_impl import *
  26. # from tests.vm_impl.vm_interface import *
  27. # from tests.vm_impl import *
  28. # context.set_context(save_graphs=True)
  29. grad_by_list = C.GradOperation(get_by_list=True)
  30. grad_all = C.GradOperation(get_all=True)
def test_while_forward():
    """Forward-only graph-mode while loop that mutates a tensor slice per iteration."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()

        def construct(self, idx, end, x):
            # idx/end are tensors, so the loop bound is data-dependent and
            # compiled as real control flow rather than unrolled in Python.
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                idx = idx + 1
            return x

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    net = MyWhileNet()
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_while_grad():
    """Gradient w.r.t. all inputs of a while loop with in-loop tensor mutation."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()

        def construct(self, idx, end, x):
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                idx = idx + 1
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            # grad_all differentiates w.r.t. every input (idx, end, x).
            return grad_all(self.net)(*inputs)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_while_with_param_forward():
    """Forward while loop that accumulates a Parameter into the loop output."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                out = out + x + self.param
                idx = idx + 1
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    net = MyWhileNet()
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_while_endless_case():
    """Regression case: loop shape that previously made optimization endless."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                part = x[idx, :, :]
                out = out + part
                idx = idx + 1
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    net = MyWhileNet()
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_while_with_param_grad():
    """Gradient w.r.t. Parameters through a while loop with tensor mutation."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                out = out + x + self.param
                idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            # Collect the wrapped net's trainable Parameters for grad_by_list.
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_while_with_param_forward_with_const_branch():
    """While loop containing an if with a compile-time-constant condition."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                # `2 > 1` is constant, so the else branch should be folded away
                # by the graph compiler.
                if 2 > 1:
                    out = out + self.param
                else:
                    out = out + idx + self.param
                idx = idx + 1
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = while_net
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_while_opt_endless():
    """Regression case: AddN-in-loop pattern that hung during optimization."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.addn = P.AddN()

        def construct(self, idx, end, x):
            addn1 = self.addn((x, x, x))
            out = addn1
            while idx < end:
                out = self.addn((out, addn1))
                idx = idx + 1
            out = self.addn((out, x))
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.ones([2, 2, 2]).astype(np.float32) * 3, dtype=ms.float32)
    net(idx, end, x)
def test_no_while_call():
    """Constant-condition if without any while loop (branch folding only)."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            # Constant condition: only the first branch is ever taken.
            if 2 > 1:
                out = out + self.param
            else:
                out = out + idx + self.param
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = while_net
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_while_with_param_grad_with_const_branch():
    """Parameter gradient through a while loop holding a constant-condition if."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                if 2 > 1:
                    out = out + self.param
                else:
                    out = out + idx + self.param
                idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_for_while_with_param_grad_with_const_branch():
    """Parameter gradient through a for loop wrapping a while loop with a constant if."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = self.zero
            # The Python for-range is unrolled at compile time; the while stays
            # as graph control flow.
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    if 2 > 1:
                        out = out + self.param
                    else:
                        out = out + idx + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_for_while_with_param_grad_basic():
    """Parameter gradient through a for loop wrapping a simple while accumulation."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = self.zero
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    out = out + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_for_while_with_param_grad_normal():
    """Like the basic for-while grad case but the accumulator starts from input x."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            # Accumulator seeded from the input rather than a constant zero.
            out = x
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    out = out + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_while_with_param_basic_grad():
    """Parameter gradient where the Parameter is used both in and after the loop."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param
                idx = idx + 1
            # Extra use of the Parameter outside the loop body.
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_while_with_param_basic_grad_mul():
    """Parameter gradient with multiplicative accumulation inside the loop."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            # Seeded with ones (not zeros) so repeated multiplication is non-trivial.
            self.zero = Tensor(np.ones(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out * self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_while_with_param_basic_grad_two():
    """Gradient w.r.t. two Parameters accumulated in the same while loop."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param + self.weight
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_while_with_param_basic_grad_three():
    """Gradient w.r.t. three Parameters accumulated in the same while loop."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
            self.key = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="key")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param + self.weight + self.key
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_while_if_with_param_grad():
    """Parameter gradient through a while loop with a data-dependent if inside."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                # Branch condition depends on runtime tensor values, so both
                # branches must survive compilation.
                if self.max(out) < self.max(x):
                    out = out + self.param * 2
                else:
                    out = out + self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_while_with_param_grad_not_enter_while():
    """Parameter gradient when the while body never executes (idx starts past end)."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param * 3
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    # idx > end, so the loop is skipped entirely.
    idx = Tensor(np.array(3), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_with_param_if_by_if_forward():
    """Forward pass through two sequential if/else blocks using a Parameter."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param
            else:
                out = out + x
            if a == b:
                out = out + x*3 + self.param
            else:
                out = out + x*2
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_with_param_if_by_if_grad_inputs():
    """Gradient w.r.t. inputs through two sequential ifs without else branches."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 4
            if a == b:
                out = out + x*3 + self.param * 3
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    # a == b here, so only the second branch fires.
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_with_param_if_by_if_grad_parameter():
    """Gradient w.r.t. Parameters through two sequential ifs without else branches."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 2
            if a == b:
                out = out + x*3 + self.param
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_with_param_if_by_if_grad_param_excute_null():
    """Parameter gradient when the single if branch is never executed."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 2
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    # a > b, so the branch body is skipped and the gradient flows through zero.
    idx = Tensor(np.array(4), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_if_by_if_return_inside_grad():
    """Parameter gradient when each if branch returns directly (early returns)."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                return out + x + self.param
            if a == b:
                return out + self.param * 2
            return out + self.param * 3

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    # a > b and a != b: falls through to the final return.
    idx = Tensor(np.array(1), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_if_by_if_forward():
    """Forward pass through three chained if/else blocks on scalar tensors."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            a = a * b
            out = a + b + x
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(4), dtype=ms.float32)
    net(idx, end, x)
def test_if_by_if_forward_control_tuple_switch():
    """tuple_getitem from a switch op should generate a new inner switch to
    eliminate the tuple_getitem (nested Cells each return/consume tuples)."""
    class Branch3Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            # Returns a tuple so the caller's switch output is a tuple.
            return a, b, x

    class Branch2Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch3Net()

        def construct(self, a, b, x):
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            return self.net(a, b, x)

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch2Net()

        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            # Tuple unpacking of a control-flow result triggers the
            # tuple_getitem-from-switch elimination under test.
            a, b, x = self.net(a, b, x)
            a = a * b
            out = a + b + x
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
def test_if_by_if_forward_control_inside_net():
    """Forward graph where each level of nested sub-nets adds its own if/else.

    Unlike the tuple_switch variant, the innermost net returns a single
    tensor, so no tuple_getitem elimination is involved.
    """
    class Branch3Net(nn.Cell):
        # Innermost net: final if/else plus the arithmetic tail.
        def __init__(self):
            super().__init__()
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            a = a * b
            out = a + b + x
            return out

    class Branch2Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch3Net()

        def construct(self, a, b, x):
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            return self.net(a, b, x)

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch2Net()

        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            out = self.net(a, b, x)
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
def test_if_by_if_forward_use_namespace():
    """Same if-by-if forward shape, but primitives are instantiated inline.

    Each branch constructs its operator via the `P` namespace (e.g.
    `P.TensorAdd()(a, b)`) instead of using the attributes created in
    `__init__`, exercising namespace resolution inside graph branches.
    """
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # Attributes are deliberately unused by construct — the test is
            # about inline `P.<Op>()` instantiation inside branches.
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if a < b:
                a = P.TensorAdd()(a, b)
            else:
                a = P.Sub()(a, b)
            if a == x:
                a = P.Mul()(a, b)
            else:
                a = P.RealDiv()(a, b)
            if b == x:
                b = P.TensorAdd()(a, b)
            else:
                b = P.TensorAdd()(a, x)
            a = a * b
            out = a + b + x
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
def test_if_by_if_forward_use_global_op():
    """If-by-if forward where operators are created once as construct locals.

    Primitives are instantiated at the top of `construct` and then shared by
    all branches, exercising free-variable capture of ops across the
    branch subgraphs.
    """
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # Unused by construct; the test targets construct-local ops.
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            # Ops created once here are referenced from every branch body.
            add = P.TensorAdd()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if a < b:
                a = add(a, b)
            else:
                a = sub(a, b)
            if a == x:
                a = mul(a, b)
            else:
                a = div(a, b)
            if b == x:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
def test_for_with_if_by_if_forward():
    """Forward graph with an if/else nested inside an unrolled `for` loop."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.TensorAdd()
            self.sub = P.Sub()

        def construct(self, a, b, x):
            # Constant-range loop (unrollable) wrapping a tensor-conditioned if.
            for _ in range(0, 4):
                if a < b:
                    a = self.add(a, b)
                else:
                    b = self.sub(b, x)
            a = a * b
            out = a + b + x
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
def test_for_with_if_by_if_forward_namespace():
    """`for` loop wrapping an if/else whose ops are instantiated inline.

    Like test_for_with_if_by_if_forward, but each branch constructs its
    primitive via the `P` namespace on every (unrolled) iteration.
    """
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # Unused by construct; the test targets inline op creation.
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            for _ in range(0, 6):
                if a < b:
                    a = P.TensorAdd()(a, b)
                else:
                    b = P.Sub()(b, x)
            a = a * b
            out = a + b + x
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
def test_if_by_if_forward_const_branch_inner():
    """If-by-if forward where the middle condition is a compile-time constant.

    `if 2 > 1` can be folded at compile time while the neighbouring branches
    depend on runtime tensors, mixing folded and real switches in one graph.
    """
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # Unused by construct; ops are created as construct locals below.
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            add = P.TensorAdd()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if a < b:
                a = add(a, b)
            else:
                a = sub(a, b)
            # Deliberately constant condition — foldable branch.
            if 2 > 1:
                a = mul(a, b)
            else:
                a = div(a, b)
            if b == x:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
def test_if_by_if_forward_all_const_branch():
    """If-by-if forward where every condition is a compile-time constant.

    All three conditions (`2 < 12`, `2 > 1`, `2 == 1`) are foldable, so the
    whole chain of branches can be resolved during compilation.
    """
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # Unused by construct; ops are created as construct locals below.
            self.add = P.TensorAdd()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            add = P.TensorAdd()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            # All conditions below are constants — every branch is foldable.
            if 2 < 12:
                a = add(a, b)
            else:
                a = sub(a, b)
            if 2 > 1:
                a = mul(a, b)
            else:
                a = div(a, b)
            if 2 == 1:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_if_const_grad():
    """Grad wrapper containing a constant-condition `if` before grad_by_list.

    The scalar if/else in GradNet.construct is resolvable at compile time;
    the test checks it does not break differentiation of the inner net.
    """
    class MyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.TensorAdd()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            # Constant scalar control flow alongside the grad call.
            a = 1
            b = 2
            if a > 0:
                b = 1
            a += b
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_if_by_if_const_grad():
    """Grad wrapper with three consecutive constant-condition `if`s.

    All conditions are on Python scalars (`a = 1`), so they fold at compile
    time; the test checks the folded chain coexists with grad_by_list.
    """
    class MyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.TensorAdd()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            a = 1
            b = 2
            # Constant scalar conditions — each branch is foldable.
            if a > 0:
                b = 1
            if a < 0:
                b = 0
            if a == 0:
                b = 3
            a += b
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_while_const_grad():
    """Grad wrapper containing a constant-condition `while` loop.

    With `a = 1`, `while a > 1` never executes; the test checks a foldable
    loop in the grad wrapper does not break differentiation.
    """
    class MyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.TensorAdd()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            # Scalar loop whose condition is false on entry.
            a = 1
            while a > 1:
                a = a - 1
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_if_by_while_const_grad():
    """Grad wrapper with a constant `if` followed by a constant `while`.

    Both the branch and the loop run on Python scalars and are resolvable
    at compile time; the test checks the combination alongside grad_by_list.
    """
    class MyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.TensorAdd()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            a = 1
            b = 2
            # Constant branch, then a loop that never executes (a == 1).
            if a > 0:
                b = 0
            while a > 1:
                a = a - 1
            a += b
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)