
test_cont_grad.py

# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """ test control ops """
import numpy as np
import pytest

from mindspore import dtype as ms
from mindspore import Tensor
from mindspore import context
from mindspore import nn
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import operations as P

# from tests.vm_impl.math_ops_vm_impl import *
# from tests.vm_impl.vm_interface import *
# from tests.vm_impl import *
# context.set_context(save_graphs=True)

grad_by_list = C.GradOperation(get_by_list=True)
grad_all = C.GradOperation(get_all=True)
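

# grad_by_list differentiates a Cell with respect to the Parameters collected
# in a ParameterTuple, while grad_all differentiates with respect to every
# positional input. A minimal sketch of the distinction; the _SketchNet helper
# below is illustrative only and is not part of the original test suite:
def _grad_operation_sketch():
    class _SketchNet(nn.Cell):
        def construct(self, x):
            return x * x

    # d(x * x)/dx = 2 * x, so this returns a one-element tuple holding Tensor(6.0).
    return grad_all(_SketchNet())(Tensor(np.array(3.0), dtype=ms.float32))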


def test_while_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()

        def construct(self, idx, end, x):
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                idx = idx + 1
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)
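

# Note: the GradNet wrapper pattern above recurs throughout this file. Gradients
# are taken by calling a GradOperation on the wrapped Cell inside another Cell's
# construct, and each test then checks that GRAPH_MODE and PYNATIVE_MODE produce
# matching forward or gradient values.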


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_with_const_param_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.mul = P.Mul()
            self.add = P.Add()

        def construct(self, x, y):
            while x < y:
                z = self.mul(x, x)
                x = self.add(z, 1)
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor([1.1], dtype=ms.float32)
    end = Tensor([8.0], dtype=ms.float32)
    graph_output = net(idx, end)
    expect_one = np.array([1.14433983e+02], dtype=np.float32)
    expect_two = np.array([0], dtype=np.float32)
    assert np.allclose(graph_output[0].asnumpy(), expect_one, 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), expect_two, 0.0001, 0.0001)
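
# For test_while_with_const_param_grad above: starting from x = 1.1 the loop
# body x <- x * x + 1 runs three times (1.1 -> 2.21 -> 5.8841 -> 35.62...), so
# the output is f(x) = ((x**2 + 1)**2 + 1)**2 + 1 and the chain rule gives
# f'(x) = 8 * x * (x**2 + 1) * ((x**2 + 1)**2 + 1) ~= 114.434 at x = 1.1,
# matching expect_one. y appears only in the loop condition, so its gradient is 0.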


def test_while_with_variable_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.mul = P.Mul()
            self.add = P.Add()

        def construct(self, x, y):
            while x < y:
                z = self.mul(x, x)
                x = self.add(z, y)
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor([1.1], dtype=ms.float32)
    end = Tensor([8.0], dtype=ms.float32)
    graph_output = net(idx, end)
    expect_one = np.array([2.20000005e+00], dtype=np.float32)
    expect_two = np.array([1.00000000e+00], dtype=np.float32)
    assert np.allclose(graph_output[0].asnumpy(), expect_one, 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), expect_two, 0.0001, 0.0001)
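
# For test_while_with_variable_grad above: with x = 1.1 and y = 8.0 the loop
# body x <- x * x + y runs exactly once (1.21 + 8.0 = 9.21 >= 8.0), so the
# output is x**2 + y, giving d(out)/dx = 2 * x = 2.2 and d(out)/dy = 1.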


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_with_param_forward():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                out = out + x + self.param
                idx = idx + 1
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    net = MyWhileNet()
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    expect = np.array([[[6, 8], [10, 12]], [[19, 22], [25, 28]]], dtype=np.int32)
    assert np.allclose(graph_output.asnumpy(), expect, 0.0001, 0.0001)
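
# For test_while_with_param_forward above: x and self.param both start as
# arange(8).reshape(2, 2, 2). Iteration 0 overwrites x[0] with its max (3) and
# adds x + param to out; iteration 1 overwrites x[1] with its max (7) and adds
# again, which yields [[[6, 8], [10, 12]], [[19, 22], [25, 28]]].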


def test_while_endless_case():
    """Case that previously caused an endless loop during graph optimization."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                part = x[idx, :, :]
                out = out + part
                idx = idx + 1
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    net = MyWhileNet()
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_with_param_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                out = out + x + self.param
                idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    expect = np.array([[[2, 2], [2, 2]], [[2, 2], [2, 2]]], dtype=np.int32)
    assert np.allclose(graph_output[0].asnumpy(), expect, 0.0001, 0.0001)
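
# For test_while_with_param_grad above: each of the two loop iterations adds
# self.param to out exactly once, so the gradient of out with respect to every
# element of self.param is 2, matching the all-twos expect array.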


def test_while_with_param_forward_with_const_branch():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                if 2 > 1:
                    out = out + self.param
                else:
                    out = out + idx + self.param
                idx = idx + 1
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = while_net
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


def test_while_opt_endless():
    """Another case that previously hung during graph optimization."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.reduce = P.ReduceSum()
            self.addn = P.AddN()

        def construct(self, idx, end, x):
            addn1 = self.addn((x, x, x))
            out = addn1
            while idx < end:
                out = self.addn((out, addn1))
                idx = idx + 1
            out = self.addn((out, x))
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.ones([2, 2, 2]).astype(np.float32) * 3, dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


def test_no_while_call():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            if 2 > 1:
                out = out + self.param
            else:
                out = out + idx + self.param
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = while_net
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


def test_while_with_param_grad_with_const_branch():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                if 2 > 1:
                    out = out + self.param
                else:
                    out = out + idx + self.param
                idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


def test_for_while_with_param_grad_with_const_branch():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = self.zero
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    if 2 > 1:
                        out = out + self.param
                    else:
                        out = out + idx + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


def test_for_while_with_param_grad_basic():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = self.zero
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    out = out + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


def test_for_while_with_param_grad_normal():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = x
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    out = out + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


def test_while_with_param_basic_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


def test_while_with_param_basic_grad_mul():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.ones([2, 2, 2]), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out * self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


def test_while_with_param_basic_grad_two():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param + self.weight
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)


def test_while_with_param_basic_grad_three():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
            self.key = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="key")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param + self.weight + self.key
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)


def test_while_if_with_param_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                if self.max(out) < self.max(x):
                    out = out + self.param * 2
                else:
                    out = out + self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


def test_while_with_param_grad_not_enter_while():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param * 3
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(3), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


def test_with_param_if_by_if_forward():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param
            else:
                out = out + x
            if a == b:
                out = out + x * 3 + self.param
            else:
                out = out + x * 2
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


def test_with_param_if_by_if_grad_inputs():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 4
            if a == b:
                out = out + x * 3 + self.param * 3
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)


def test_with_param_if_by_if_grad_parameter():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 2
            if a == b:
                out = out + x * 3 + self.param
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


def test_with_param_if_by_if_grad_param_execute_null():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 2
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    idx = Tensor(np.array(4), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


def test_if_by_if_return_inside_grad():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                return out + x + self.param
            if a == b:
                return out + self.param * 2
            return out + self.param * 3

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    idx = Tensor(np.array(1), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)


def test_if_by_if_forward():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(4), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


def test_if_by_if_forward_control_tuple_switch():
    """A tuple_getitem on a switch op's output generates a new inner switch so the tuple_getitem can be eliminated."""
    class Branch3Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            return a, b, x

    class Branch2Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch3Net()

        def construct(self, a, b, x):
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            return self.net(a, b, x)

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch2Net()

        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            a, b, x = self.net(a, b, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


def test_if_by_if_forward_control_inside_net():
    class Branch3Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            a = a * b
            out = a + b + x
            return out

    class Branch2Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch3Net()

        def construct(self, a, b, x):
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            return self.net(a, b, x)

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch2Net()

        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            out = self.net(a, b, x)
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


def test_if_by_if_forward_use_namespace():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if a < b:
                a = P.Add()(a, b)
            else:
                a = P.Sub()(a, b)
            if a == x:
                a = P.Mul()(a, b)
            else:
                a = P.RealDiv()(a, b)
            if b == x:
                b = P.Add()(a, b)
            else:
                b = P.Add()(a, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


def test_if_by_if_forward_use_global_op():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if a < b:
                a = add(a, b)
            else:
                a = sub(a, b)
            if a == x:
                a = mul(a, b)
            else:
                a = div(a, b)
            if b == x:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


def test_for_with_if_by_if_forward():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()

        def construct(self, a, b, x):
            for _ in range(0, 4):
                if a < b:
                    a = self.add(a, b)
                else:
                    b = self.sub(b, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


def test_for_with_if_by_if_forward_namespace():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            for _ in range(0, 6):
                if a < b:
                    a = P.Add()(a, b)
                else:
                    b = P.Sub()(b, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


def test_if_by_if_forward_const_branch_inner():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if a < b:
                a = add(a, b)
            else:
                a = sub(a, b)
            if 2 > 1:
                a = mul(a, b)
            else:
                a = div(a, b)
            if b == x:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


def test_if_by_if_forward_all_const_branch():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if 2 < 12:
                a = add(a, b)
            else:
                a = sub(a, b)
            if 2 > 1:
                a = mul(a, b)
            else:
                a = div(a, b)
            if 2 == 1:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_if_const_grad():
    class MyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            a = 1
            b = 2
            if a > 0:
                b = 1
            a += b
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_if_by_if_const_grad():
    class MyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            a = 1
            b = 2
            if a > 0:
                b = 1
            if a < 0:
                b = 0
            if a == 0:
                b = 3
            a += b
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_while_const_grad():
    class MyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            a = 1
            while a > 1:
                a = a - 1
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_if_by_while_const_grad():
    class MyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            a = 1
            b = 2
            if a > 0:
                b = 0
            while a > 1:
                a = a - 1
            a += b
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)