You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_cont_cases.py 32 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test control ops """
  16. import pytest
  17. import numpy as np
  18. from mindspore import dtype as ms
  19. from mindspore import Tensor
  20. from mindspore import context
  21. from mindspore import nn
  22. from mindspore import ms_function
  23. from mindspore.common.parameter import Parameter, ParameterTuple
  24. from mindspore.ops import composite as C
  25. from mindspore.ops import operations as P
  26. # from tests.vm_impl.math_ops_vm_impl import *
  27. # from tests.vm_impl.vm_interface import *
  28. # from tests.vm_impl import *
# Gradient helpers shared by every test in this module.
grad_by_list = C.GradOperation(get_by_list=True)  # gradients w.r.t. a ParameterTuple of weights
grad_all = C.GradOperation(get_all=True)  # gradients w.r.t. all network inputs
@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
    """Run all tests in this module in PyNative precompile-only mode.

    precompile_only=True means graphs are compiled but never executed,
    so these tests only validate control-flow compilation.
    The previous GRAPH_MODE configuration is restored on teardown.
    """
    context.set_context(mode=context.PYNATIVE_MODE, precompile_only=True)
    yield
    context.set_context(mode=context.GRAPH_MODE, precompile_only=False)
def test_while_with_param_forward_with_const_branch():
    """Forward while loop whose inner if-condition (2 > 1) folds to a constant at compile time."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()

        @ms_function
        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                # Constant-true branch: the else arm should be eliminated by the compiler.
                if 2 > 1:
                    out = out + self.param
                else:
                    out = out + idx + self.param
                idx = idx + 1
            return out

    while_net = MyWhileNet()
    net = while_net
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_while_opt_endless():
    """Endless-loop-during-optimization regression case: grad of a while loop built from AddN."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.addn = P.AddN()

        def construct(self, idx, end, x):
            addn1 = self.addn((x, x, x))
            out = addn1
            while idx < end:
                out = self.addn((out, addn1))
                idx = idx + 1
            # One more accumulation after the loop exits.
            out = self.addn((out, x))
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        @ms_function
        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.ones([2, 2, 2]).astype(np.float32) * 3, dtype=ms.float32)
    net(idx, end, x)
def test_no_while_call():
    """Forward pass with only a constant-condition if (no while loop at all)."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()

        @ms_function
        def construct(self, idx, end, x):
            out = self.zero
            # Constant-true condition: only the first branch is ever taken.
            if 2 > 1:
                out = out + self.param
            else:
                out = out + idx + self.param
            return out

    while_net = MyWhileNet()
    net = while_net
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_while_with_param_grad_with_const_branch():
    """Gradient w.r.t. weights of a while loop containing a constant-condition if."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                if 2 > 1:
                    out = out + self.param
                else:
                    out = out + idx + self.param
                idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_for_while_with_param_grad_with_const_branch():
    """Gradient of a while loop (with const-condition if) nested inside a for loop."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = self.zero
            for _ in range(0, 2):
                # Reset the loop counter so the while body runs again on each outer iteration.
                idx = self.start
                while idx < end:
                    if 2 > 1:
                        out = out + self.param
                    else:
                        out = out + idx + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_for_while_with_param_grad_basic():
    """Gradient of a plain while loop nested inside a for loop (accumulates into zeros)."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = self.zero
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    out = out + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_for_while_with_param_grad_normal():
    """Same as the basic for/while grad case but the accumulator starts from the input x."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            out = x
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    out = out + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_while_with_param_basic_grad():
    """Basic gradient w.r.t. a single weight used inside and after a while loop."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param
                idx = idx + 1
            # Parameter also contributes outside the loop.
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_while_with_param_basic_grad_mul():
    """Gradient of a while loop that multiplies by the weight (starts from ones, not zeros)."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            # NOTE: named "zero" but initialized to ones so the multiplicative loop is non-trivial.
            self.zero = Tensor(np.ones(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out * self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_while_with_param_basic_grad_two():
    """Gradient w.r.t. two distinct weights accumulated inside a while loop."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param + self.weight
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_while_with_param_basic_grad_three():
    """Gradient w.r.t. three distinct weights accumulated inside a while loop."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
            self.key = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="key")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param + self.weight + self.key
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_while_if_with_param_grad():
    """Gradient of a while loop whose if-condition depends on runtime tensor values."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                # Data-dependent branch: compares reductions of the running output and input.
                if self.max(out) < self.max(x):
                    out = out + self.param * 2
                else:
                    out = out + self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_while_with_param_grad_not_enter_while():
    """Gradient when the while loop body never executes (idx=3 >= end=0)."""
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                out = out + self.param * 3
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    # idx > end on entry, so the loop is skipped entirely.
    idx = Tensor(np.array(3), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_with_param_if_by_if_forward():
    """Forward pass through two sequential if/else blocks that read a weight."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        @ms_function
        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param
            else:
                out = out + x
            if a == b:
                out = out + x*3 + self.param
            else:
                out = out + x*2
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_with_param_if_by_if_grad_inputs():
    """Gradient w.r.t. all inputs through two sequential ifs (no else arms)."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 4
            if a == b:
                out = out + x*3 + self.param * 3
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net

        @ms_function
        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    # a == b == 0, so only the second if fires.
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_with_param_if_by_if_grad_parameter():
    """Gradient w.r.t. weights through two sequential ifs (no else arms)."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 2
            if a == b:
                out = out + x*3 + self.param
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    # a < b, so only the first if fires.
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
# NOTE(review): "excute" is a typo for "execute"; the name is kept so test selection by name
# elsewhere (CI filters, -k expressions) is not broken.
def test_with_param_if_by_if_grad_param_excute_null():
    """Gradient w.r.t. weights when the single if branch is never taken (a > b)."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 2
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    # a=4 > b=0: the branch body never executes.
    idx = Tensor(np.array(4), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_if_by_if_return_inside_grad():
    """Gradient through a construct that returns early from inside if branches."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                return out + x + self.param
            if a == b:
                return out + self.param * 2
            return out + self.param * 3

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    # a=1 > b=0: falls through both ifs to the final return.
    idx = Tensor(np.array(1), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
def test_if_by_if_forward():
    """Forward pass through three chained if/else blocks using member operators."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        @ms_function
        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(4), dtype=ms.float32)
    net(idx, end, x)
def test_if_by_if_forward_control_tuple_switch():
    """tuple_getitem from a switch op should generate a new inner switch to eliminate the tuple_getitem."""
    class Branch3Net(nn.Cell):
        # Innermost branch: returns a 3-tuple, forcing tuple_getitem on the switch output.
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            return a, b, x

    class Branch2Net(nn.Cell):
        # Middle branch: its own if/else, then delegates to Branch3Net.
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch3Net()

        def construct(self, a, b, x):
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            return self.net(a, b, x)

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch2Net()

        @ms_function
        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            # Unpack the tuple produced through nested switches.
            a, b, x = self.net(a, b, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
def test_if_by_if_forward_control_inside_net():
    """Chained ifs spread across nested sub-Cells; the innermost net computes the final output."""
    class Branch3Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            a = a * b
            out = a + b + x
            return out

    class Branch2Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch3Net()

        def construct(self, a, b, x):
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            return self.net(a, b, x)

    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch2Net()

        @ms_function
        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            out = self.net(a, b, x)
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
def test_if_by_if_forward_use_namespace():
    """Same if-by-if chain but instantiating primitive ops inline (P.Add()(...)) inside construct."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        @ms_function
        def construct(self, a, b, x):
            if a < b:
                a = P.Add()(a, b)
            else:
                a = P.Sub()(a, b)
            if a == x:
                a = P.Mul()(a, b)
            else:
                a = P.RealDiv()(a, b)
            if b == x:
                b = P.Add()(a, b)
            else:
                b = P.Add()(a, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
def test_if_by_if_forward_use_global_op():
    """Same if-by-if chain but binding primitive ops to locals at the top of construct."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        @ms_function
        def construct(self, a, b, x):
            # Ops created as construct-local objects rather than Cell attributes.
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if a < b:
                a = add(a, b)
            else:
                a = sub(a, b)
            if a == x:
                a = mul(a, b)
            else:
                a = div(a, b)
            if b == x:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
def test_for_with_if_by_if_forward():
    """Forward if/else inside a fixed-trip-count for loop, using member operators."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()

        @ms_function
        def construct(self, a, b, x):
            for _ in range(0, 4):
                if a < b:
                    a = self.add(a, b)
                else:
                    b = self.sub(b, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
def test_for_with_if_by_if_forward_namespace():
    """Forward if/else inside a for loop, instantiating primitive ops inline each iteration."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        @ms_function
        def construct(self, a, b, x):
            for _ in range(0, 6):
                if a < b:
                    a = P.Add()(a, b)
                else:
                    b = P.Sub()(b, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
def test_if_by_if_forward_const_branch_inner():
    """If-by-if chain where the middle if has a compile-time-constant condition."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        @ms_function
        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if a < b:
                a = add(a, b)
            else:
                a = sub(a, b)
            # Constant condition: this branch should be folded at compile time.
            if 2 > 1:
                a = mul(a, b)
            else:
                a = div(a, b)
            if b == x:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
def test_if_by_if_forward_all_const_branch():
    """If-by-if chain in which every condition is a compile-time constant."""
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        @ms_function
        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            # All three conditions are constants; the compiler should prune the dead arms.
            if 2 < 12:
                a = add(a, b)
            else:
                a = sub(a, b)
            if 2 > 1:
                a = mul(a, b)
            else:
                a = div(a, b)
            if 2 == 1:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)