test_cont_grad.py

# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test gradients of control-flow ops (while/if) in graph and PyNative mode."""
import numpy as np
import pytest
from mindspore import dtype as ms
from mindspore import Tensor
from mindspore import context
from mindspore import nn
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import operations as P
# from tests.vm_impl.math_ops_vm_impl import *
# from tests.vm_impl.vm_interface import *
# from tests.vm_impl import *
# context.set_context(save_graphs=True)

grad_by_list = C.GradOperation(get_by_list=True)
grad_all = C.GradOperation(get_all=True)
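
# A minimal sketch of how these two operators are used throughout this file
# (assuming a Cell `net` and `weights = ParameterTuple(net.trainable_params())`):
#     grad_by_list(net, weights)(*inputs)  # gradients w.r.t. each Parameter in `weights`
#     grad_all(net)(*inputs)               # gradients w.r.t. each network input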


def test_while_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()

        def construct(self, idx, end, x):
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                idx = idx + 1
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)
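
# Most tests below follow the same pattern: build a net, run it once under
# GRAPH_MODE and once under PYNATIVE_MODE on identical inputs, and require the
# two results (forward values or gradients) to agree within rtol=atol=1e-4.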


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_with_const_param_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.mul = P.Mul()
            self.add = P.Add()

        def construct(self, x, y):
            while x < y:
                z = self.mul(x, x)
                x = self.add(z, 1)
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor([1.1], dtype=ms.float32)
    end = Tensor([8.0], dtype=ms.float32)
    graph_output = net(idx, end)
    # The loop body x <- x * x + 1 runs three times (1.1 -> 2.21 -> 5.8841 ->
    # 35.62...), so by the chain rule d out / d x0 = (2 * 1.1) * (2 * 2.21) *
    # (2 * 5.8841) ~= 114.434; y only appears in the loop condition, so its
    # gradient is zero.
    expect_one = np.array([1.14433983e+02], dtype=np.float32)
    expect_two = np.array([0], dtype=np.float32)
    assert np.allclose(graph_output[0].asnumpy(), expect_one, 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), expect_two, 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_with_variable_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.mul = P.Mul()
            self.add = P.Add()

        def construct(self, x, y):
            while x < y:
                z = self.mul(x, x)
                x = self.add(z, y)
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor([1.1], dtype=ms.float32)
    end = Tensor([8.0], dtype=ms.float32)
    graph_output = net(idx, end)
    # The loop runs exactly once (1.1 -> 1.1 * 1.1 + 8 = 9.21 >= 8), so
    # d out / d x = 2 * 1.1 = 2.2 and d out / d y = 1.
    expect_one = np.array([2.20000005e+00], dtype=np.float32)
    expect_two = np.array([1.00000000e+00], dtype=np.float32)
    assert np.allclose(graph_output[0].asnumpy(), expect_one, 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), expect_two, 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_with_param_forward():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                out = out + x + self.param
                idx = idx + 1
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    net = MyWhileNet()
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    expect = np.array([[[6, 8], [10, 12]], [[19, 22], [25, 28]]], dtype=np.int32)
    assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_endless_case():
    """Case that used to loop endlessly during graph optimization."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                part = x[idx, :, :]
                out = out + part
                idx = idx + 1
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    net = MyWhileNet()
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_with_param_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                out = out + x + self.param
                idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # `param` is added to `out` once per iteration and the loop runs twice,
    # so every element of d out / d param is 2.
    expect = np.array([[[2, 2], [2, 2]], [[2, 2], [2, 2]]], dtype=np.int32)
    assert np.allclose(graph_output[0].asnumpy(), expect, 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_with_param_forward_with_const_branch():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                if 2 > 1:
                    out = out + self.param
                else:
                    out = out + idx + self.param
                idx = idx + 1
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = while_net
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_opt_endless():
    """Case that used to hang in an endless loop during graph optimization."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.reduce = P.ReduceSum()
            self.addn = P.AddN()

        def construct(self, idx, end, x):
            addn1 = self.addn((x, x, x))
            out = addn1
            while idx < end:
                out = self.addn((out, addn1))
                idx = idx + 1
            out = self.addn((out, x))
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.ones([2, 2, 2]).astype(np.float32) * 3, dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


@pytest.mark.skip(reason="not supported yet")
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_no_while_call():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            if 2 > 1:
                out = out + self.param
            else:
                out = out + idx + self.param
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = while_net
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_with_param_grad_with_const_branch():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                if 2 > 1:
                    out = out + self.param
                else:
                    out = out + idx + self.param
                idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


@pytest.mark.skip(reason="not supported yet")
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_while_with_param_grad_with_const_branch():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = self.zero
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    if 2 > 1:
                        out = out + self.param
                    else:
                        out = out + idx + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_while_with_param_grad_basic():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = self.zero
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    out = out + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_while_with_param_grad_normal():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = x
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    out = out + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_with_param_basic_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_with_param_basic_grad_mul():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.ones([2, 2, 2]), ms.float32)  # note: ones (multiplicative identity), despite the name
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out * self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_with_param_basic_grad_two():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param + self.weight
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_with_param_basic_grad_three():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
            self.key = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="key")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param + self.weight + self.key
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_if_with_param_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                if self.max(out) < self.max(x):
                    out = out + self.param * 2
                else:
                    out = out + self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


@pytest.mark.skip(reason="not supported yet")
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_with_param_grad_not_enter_while():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param * 3
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode; idx > end, so the loop body is never entered
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(3), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_with_param_if_by_if_forward():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param
            else:
                out = out + x
            if a == b:
                out = out + x * 3 + self.param
            else:
                out = out + x * 2
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_with_param_if_by_if_grad_inputs():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 4
            if a == b:
                out = out + x * 3 + self.param * 3
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_with_param_if_by_if_grad_parameter():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 2
            if a == b:
                out = out + x * 3 + self.param
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_with_param_if_by_if_grad_param_execute_null():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 2
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    # graph mode; a > b, so the branch that uses `param` is never executed
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    idx = Tensor(np.array(4), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_by_if_return_inside_grad():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                return out + x + self.param
            if a == b:
                return out + self.param * 2
            return out + self.param * 3

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    idx = Tensor(np.array(1), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_by_if_forward():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(4), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
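
# Note: the if-by-if forward tests use 0-d float tensors as inputs, so each
# comparison (a < b, a == x, ...) is evaluated as a tensor condition and, in
# graph mode, each if/else pair compiles into data-driven control flow rather
# than being folded at compile time.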


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_by_if_forward_control_tuple_switch():
    """A tuple_getitem taken from a switch op generates a new switch inside, so the tuple_getitem can be eliminated."""
    class Branch3Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            return a, b, x

    class Branch2Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch3Net()

        def construct(self, a, b, x):
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            return self.net(a, b, x)

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch2Net()

        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            a, b, x = self.net(a, b, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_by_if_forward_control_inside_net():
    class Branch3Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            a = a * b
            out = a + b + x
            return out

    class Branch2Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch3Net()

        def construct(self, a, b, x):
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            return self.net(a, b, x)

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch2Net()

        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            out = self.net(a, b, x)
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_by_if_forward_use_namespace():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if a < b:
                a = P.Add()(a, b)
            else:
                a = P.Sub()(a, b)
            if a == x:
                a = P.Mul()(a, b)
            else:
                a = P.RealDiv()(a, b)
            if b == x:
                b = P.Add()(a, b)
            else:
                b = P.Add()(a, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_by_if_forward_use_global_op():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if a < b:
                a = add(a, b)
            else:
                a = sub(a, b)
            if a == x:
                a = mul(a, b)
            else:
                a = div(a, b)
            if b == x:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_with_if_by_if_forward():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()

        def construct(self, a, b, x):
            for _ in range(0, 4):
                if a < b:
                    a = self.add(a, b)
                else:
                    b = self.sub(b, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_with_if_by_if_forward_namespace():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            for _ in range(0, 6):
                if a < b:
                    a = P.Add()(a, b)
                else:
                    b = P.Sub()(b, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_by_if_forward_const_branch_inner():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if a < b:
                a = add(a, b)
            else:
                a = sub(a, b)
            if 2 > 1:
                a = mul(a, b)
            else:
                a = div(a, b)
            if b == x:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_by_if_forward_all_const_branch():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if 2 < 12:
                a = add(a, b)
            else:
                a = sub(a, b)
            if 2 > 1:
                a = mul(a, b)
            else:
                a = div(a, b)
            if 2 == 1:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
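
# The remaining CPU tests only check that Python-scalar control flow inside
# GradNet.construct (constant at compile time) folds away and does not break
# gradient compilation. MyNet holds no Parameters, so `weights` is an empty
# ParameterTuple and no output values are asserted.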


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_if_const_grad():
    class MyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            a = 1
            b = 2
            if a > 0:
                b = 1
            a += b
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_if_by_if_const_grad():
    class MyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            a = 1
            b = 2
            if a > 0:
                b = 1
            if a < 0:
                b = 0
            if a == 0:
                b = 3
            a += b
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_while_const_grad():
    class MyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            a = 1
            while a > 1:
                a = a - 1
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_if_by_while_const_grad():
    class MyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            a = 1
            b = 2
            if a > 0:
                b = 0
            while a > 1:
                a = a - 1
            a += b
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)