test_cont_grad.py 51 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test control ops """
import numpy as np
import pytest
from mindspore import dtype as ms
from mindspore import Tensor
from mindspore import context
from mindspore import nn
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.ops import composite as C
from mindspore.ops import operations as P
# from tests.vm_impl.math_ops_vm_impl import *
# from tests.vm_impl.vm_interface import *
# from tests.vm_impl import *
# context.set_context(save_graphs=True)

grad_by_list = C.GradOperation(get_by_list=True)
grad_all = C.GradOperation(get_all=True)

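# Each test below builds a small network that uses Python control flow,
# runs it once in graph mode and once in PyNative mode on identical inputs,
# and asserts that the two results (forward outputs or gradients) agree to
# within a 1e-4 relative/absolute tolerance.
# grad_by_list differentiates a network with respect to a ParameterTuple of
# its weights; grad_all differentiates it with respect to all of its inputs.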
def test_while_forward():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()

        def construct(self, idx, end, x):
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                idx = idx + 1
            return x

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    net = MyWhileNet()
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)

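# Gradient variant of the test above: grad_all(self.net)(*inputs) returns one
# gradient per network input (idx, end, x), so all three entries are compared
# across the two modes.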
def test_while_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()

        def construct(self, idx, end, x):
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                idx = idx + 1
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)

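# The loop below both rewrites a slice of x and accumulates the "weight"
# Parameter into the running output, so the forward result depends on the
# Parameter as well as on the inputs.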
def test_while_with_param_forward():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                out = out + x + self.param
                idx = idx + 1
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    net = MyWhileNet()
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)

def test_while_endless_case():
    """Case that could loop endlessly during graph optimization."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                part = x[idx, :, :]
                out = out + part
                idx = idx + 1
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    net = MyWhileNet()
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)

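# grad_by_list returns gradients only for the parameters collected in
# self.weights, so graph_output[0] is the gradient of the single "weight"
# Parameter.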
def test_while_with_param_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                out = out + x + self.param
                idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)

def test_while_with_param_forward_with_const_branch():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                if 2 > 1:
                    out = out + self.param
                else:
                    out = out + idx + self.param
                idx = idx + 1
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = while_net
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)

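# Regression guard: the chained AddN calls below are a pattern that could
# make graph optimization loop endlessly (see the docstring).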
def test_while_opt_endless():
    """Case that could hang (endless loop) during graph optimization."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.reduce = P.ReduceSum()
            self.addn = P.AddN()

        def construct(self, idx, end, x):
            addn1 = self.addn((x, x, x))
            out = addn1
            while idx < end:
                out = self.addn((out, addn1))
                idx = idx + 1
            out = self.addn((out, x))
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.ones([2, 2, 2]).astype(np.float32) * 3, dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)

def test_no_while_call():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            if 2 > 1:
                out = out + self.param
            else:
                out = out + idx + self.param
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = while_net
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)

def test_while_with_param_grad_with_const_branch():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                if 2 > 1:
                    out = out + self.param
                else:
                    out = out + idx + self.param
                idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)

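# A while loop nested inside a for loop: idx is reset to self.start on every
# for iteration, so the while body executes 2 * end times in total.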
def test_for_while_with_param_grad_with_const_branch():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = self.zero
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    if 2 > 1:
                        out = out + self.param
                    else:
                        out = out + idx + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)

def test_for_while_with_param_grad_basic():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = self.zero
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    out = out + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)

def test_for_while_with_param_grad_normal():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = x
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    out = out + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)

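# With end = 3, self.param enters the sum four times (three loop iterations
# plus once in the return), so its gradient should be a tensor of 4s,
# assuming the default all-ones sensitivity.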
def test_while_with_param_basic_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)

def test_while_with_param_basic_grad_mul():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.ones([2, 2, 2]), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out * self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)

def test_while_with_param_basic_grad_two():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param + self.weight
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)

def test_while_with_param_basic_grad_three():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
            self.key = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="key")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param + self.weight + self.key
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)

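# Here the branch taken inside the while loop depends on runtime tensor
# values (self.max(out) vs. self.max(x)), so the branch cannot be folded
# away at compile time.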
def test_while_if_with_param_grad():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                if self.max(out) < self.max(x):
                    out = out + self.param * 2
                else:
                    out = out + self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)

def test_while_with_param_grad_not_enter_while():
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param * 3
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(3), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)

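# "if by if" tests: two or more sequential (non-nested) conditionals, each
# compiled to its own switch in graph mode.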
def test_with_param_if_by_if_forward():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param
            else:
                out = out + x
            if a == b:
                out = out + x * 3 + self.param
            else:
                out = out + x * 2
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)

def test_with_param_if_by_if_grad_inputs():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 4
            if a == b:
                out = out + x * 3 + self.param * 3
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)

def test_with_param_if_by_if_grad_parameter():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 2
            if a == b:
                out = out + x * 3 + self.param
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)

def test_with_param_if_by_if_grad_param_execute_null():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 2
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    idx = Tensor(np.array(4), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)

def test_if_by_if_return_inside_grad():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros([2, 2, 2]), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                return out + x + self.param
            if a == b:
                return out + self.param * 2
            return out + self.param * 3

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    idx = Tensor(np.array(1), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)

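# The remaining forward tests operate on scalar (0-d) tensors and exercise
# the arithmetic ops Add, Sub, Mul and RealDiv under various branch layouts.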
def test_if_by_if_forward():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(4), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)

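# In the next test, the conditional branches live in nested sub-networks
# that return the tuple (a, b, x), so a tuple_getitem is applied to the
# output of a switch op.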
def test_if_by_if_forward_control_tuple_switch():
    """A tuple_getitem on the output of a switch op generates a new switch inside, eliminating the tuple_getitem."""
    class Branch3Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            return a, b, x

    class Branch2Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch3Net()

        def construct(self, a, b, x):
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            return self.net(a, b, x)

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch2Net()

        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            a, b, x = self.net(a, b, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)

def test_if_by_if_forward_control_inside_net():
    class Branch3Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            a = a * b
            out = a + b + x
            return out

    class Branch2Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch3Net()

        def construct(self, a, b, x):
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            return self.net(a, b, x)

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch2Net()

        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            out = self.net(a, b, x)
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)

def test_if_by_if_forward_use_namespace():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if a < b:
                a = P.Add()(a, b)
            else:
                a = P.Sub()(a, b)
            if a == x:
                a = P.Mul()(a, b)
            else:
                a = P.RealDiv()(a, b)
            if b == x:
                b = P.Add()(a, b)
            else:
                b = P.Add()(a, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)

def test_if_by_if_forward_use_global_op():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if a < b:
                a = add(a, b)
            else:
                a = sub(a, b)
            if a == x:
                a = mul(a, b)
            else:
                a = div(a, b)
            if b == x:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)

def test_for_with_if_by_if_forward():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()

        def construct(self, a, b, x):
            for _ in range(0, 4):
                if a < b:
                    a = self.add(a, b)
                else:
                    b = self.sub(b, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)

def test_for_with_if_by_if_forward_namespace():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            for _ in range(0, 6):
                if a < b:
                    a = P.Add()(a, b)
                else:
                    b = P.Sub()(b, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)

def test_if_by_if_forward_const_branch_inner():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if a < b:
                a = add(a, b)
            else:
                a = sub(a, b)
            if 2 > 1:
                a = mul(a, b)
            else:
                a = div(a, b)
            if b == x:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)

def test_if_by_if_forward_all_const_branch():
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if 2 < 12:
                a = add(a, b)
            else:
                a = sub(a, b)
            if 2 > 1:
                a = mul(a, b)
            else:
                a = div(a, b)
            if 2 == 1:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    # graph mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    graph_output = net(idx, end, x)
    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
    pynative_output = net(idx, end, x)
    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)

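# The CPU-marked tests below only check that constant control flow inside a
# gradient network compiles and runs in graph mode; they make no numeric
# assertion.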
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_if_const_grad():
    class MyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            a = 1
            b = 2
            if a > 0:
                b = 1
            a += b
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_if_by_if_const_grad():
    class MyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            a = 1
            b = 2
            if a > 0:
                b = 1
            if a < 0:
                b = 0
            if a == 0:
                b = 3
            a += b
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_while_const_grad():
    class MyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            a = 1
            while a > 1:
                a = a - 1
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_if_by_while_const_grad():
    class MyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()

        def construct(self, *inputs):
            out = self.add(*inputs)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            a = 1
            b = 2
            if a > 0:
                b = 0
            while a > 1:
                a = a - 1
            a += b
            return grad_by_list(self.net, self.weights)(*inputs)

    context.set_context(mode=context.GRAPH_MODE)
    my_net = MyNet()
    net = GradNet(my_net)
    a = Tensor(np.array(0), dtype=ms.int32)
    b = Tensor(np.array(1), dtype=ms.int32)
    net(a, b)