You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_cont_cases.py 32 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test control ops """
  16. import numpy as np
  17. from mindspore import dtype as ms
  18. from mindspore import Tensor
  19. from mindspore import context
  20. from mindspore import nn
  21. from mindspore import ms_function
  22. from mindspore.common.parameter import Parameter, ParameterTuple
  23. from mindspore.ops import composite as C
  24. from mindspore.ops import operations as P
  25. # from tests.vm_impl.math_ops_vm_impl import *
  26. # from tests.vm_impl.vm_interface import *
  27. # from tests.vm_impl import *
  28. grad_by_list = C.GradOperation(get_by_list=True)
  29. grad_all = C.GradOperation(get_all=True)
  30. def setup_module():
  31. context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False)
  32. def test_while_with_param_forward_with_const_branch():
  33. class MyWhileNet(nn.Cell):
  34. def __init__(self):
  35. super().__init__()
  36. self.max = P.ReduceMax()
  37. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  38. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  39. self.reduce = P.ReduceSum()
  40. @ms_function
  41. def construct(self, idx, end, x):
  42. out = self.zero
  43. while idx < end:
  44. if 2 > 1:
  45. out = out + self.param
  46. else:
  47. out = out + idx + self.param
  48. idx = idx + 1
  49. return out
  50. while_net = MyWhileNet()
  51. net = while_net
  52. idx = Tensor(np.array(0), dtype=ms.int32)
  53. end = Tensor(np.array(4), dtype=ms.int32)
  54. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  55. net(idx, end, x)
  56. def test_while_opt_endless():
  57. """endless during optimization case"""
  58. class MyWhileNet(nn.Cell):
  59. def __init__(self):
  60. super().__init__()
  61. self.max = P.ReduceMax()
  62. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  63. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  64. self.reduce = P.ReduceSum()
  65. self.addn = P.AddN()
  66. def construct(self, idx, end, x):
  67. addn1 = self.addn((x, x, x))
  68. out = addn1
  69. while idx < end:
  70. out = self.addn((out, addn1))
  71. idx = idx + 1
  72. out = self.addn((out, x))
  73. return out
  74. class GradNet(nn.Cell):
  75. def __init__(self, net):
  76. super(GradNet, self).__init__()
  77. self.net = net
  78. @ms_function
  79. def construct(self, *inputs):
  80. return grad_all(self.net)(*inputs)
  81. while_net = MyWhileNet()
  82. net = GradNet(while_net)
  83. idx = Tensor(np.array(0), dtype=ms.int32)
  84. end = Tensor(np.array(4), dtype=ms.int32)
  85. x = Tensor(np.ones([2, 2, 2]).astype(np.float32) * 3, dtype=ms.float32)
  86. net(idx, end, x)
  87. def test_no_while_call():
  88. class MyWhileNet(nn.Cell):
  89. def __init__(self):
  90. super().__init__()
  91. self.max = P.ReduceMax()
  92. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  93. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  94. self.reduce = P.ReduceSum()
  95. @ms_function
  96. def construct(self, idx, end, x):
  97. out = self.zero
  98. if 2 > 1:
  99. out = out + self.param
  100. else:
  101. out = out + idx + self.param
  102. return out
  103. while_net = MyWhileNet()
  104. net = while_net
  105. idx = Tensor(np.array(0), dtype=ms.int32)
  106. end = Tensor(np.array(4), dtype=ms.int32)
  107. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  108. net(idx, end, x)
  109. def test_while_with_param_grad_with_const_branch():
  110. class MyWhileNet(nn.Cell):
  111. def __init__(self):
  112. super().__init__()
  113. self.max = P.ReduceMax()
  114. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  115. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  116. self.reduce = P.ReduceSum()
  117. def construct(self, idx, end, x):
  118. out = self.zero
  119. while idx < end:
  120. if 2 > 1:
  121. out = out + self.param
  122. else:
  123. out = out + idx + self.param
  124. idx = idx + 1
  125. return out
  126. class GradNet(nn.Cell):
  127. def __init__(self, net):
  128. super(GradNet, self).__init__()
  129. self.net = net
  130. self.weights = ParameterTuple(net.trainable_params())
  131. @ms_function
  132. def construct(self, a, b, c):
  133. return grad_by_list(self.net, self.weights)(a, b, c)
  134. while_net = MyWhileNet()
  135. net = GradNet(while_net)
  136. idx = Tensor(np.array(0), dtype=ms.int32)
  137. end = Tensor(np.array(4), dtype=ms.int32)
  138. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  139. net(idx, end, x)
  140. def test_for_while_with_param_grad_with_const_branch():
  141. class MyWhileNet(nn.Cell):
  142. def __init__(self):
  143. super().__init__()
  144. self.max = P.ReduceMax()
  145. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  146. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  147. self.reduce = P.ReduceSum()
  148. self.start = Tensor(np.array(0), dtype=ms.int32)
  149. def construct(self, idx, end, x):
  150. out = self.zero
  151. for _ in range(0, 2):
  152. idx = self.start
  153. while idx < end:
  154. if 2 > 1:
  155. out = out + self.param
  156. else:
  157. out = out + idx + self.param
  158. idx = idx + 1
  159. return out
  160. class GradNet(nn.Cell):
  161. def __init__(self, net):
  162. super(GradNet, self).__init__()
  163. self.net = net
  164. self.weights = ParameterTuple(net.trainable_params())
  165. @ms_function
  166. def construct(self, a, b, c):
  167. return grad_by_list(self.net, self.weights)(a, b, c)
  168. while_net = MyWhileNet()
  169. net = GradNet(while_net)
  170. idx = Tensor(np.array(0), dtype=ms.int32)
  171. end = Tensor(np.array(4), dtype=ms.int32)
  172. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  173. net(idx, end, x)
  174. def test_for_while_with_param_grad_basic():
  175. class MyWhileNet(nn.Cell):
  176. def __init__(self):
  177. super().__init__()
  178. self.max = P.ReduceMax()
  179. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  180. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  181. self.reduce = P.ReduceSum()
  182. self.start = Tensor(np.array(0), dtype=ms.int32)
  183. def construct(self, idx, end, x):
  184. out = self.zero
  185. for _ in range(0, 2):
  186. idx = self.start
  187. while idx < end:
  188. out = out + self.param
  189. idx = idx + 1
  190. return out
  191. class GradNet(nn.Cell):
  192. def __init__(self, net):
  193. super(GradNet, self).__init__()
  194. self.net = net
  195. self.weights = ParameterTuple(net.trainable_params())
  196. @ms_function
  197. def construct(self, a, b, c):
  198. return grad_by_list(self.net, self.weights)(a, b, c)
  199. while_net = MyWhileNet()
  200. net = GradNet(while_net)
  201. idx = Tensor(np.array(0), dtype=ms.int32)
  202. end = Tensor(np.array(4), dtype=ms.int32)
  203. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  204. net(idx, end, x)
  205. def test_for_while_with_param_grad_normal():
  206. class MyWhileNet(nn.Cell):
  207. def __init__(self):
  208. super().__init__()
  209. self.max = P.ReduceMax()
  210. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  211. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  212. self.reduce = P.ReduceSum()
  213. self.start = Tensor(np.array(0), dtype=ms.int32)
  214. def construct(self, idx, end, x):
  215. out = x
  216. for _ in range(0, 2):
  217. idx = self.start
  218. while idx < end:
  219. out = out + self.param
  220. idx = idx + 1
  221. return out
  222. class GradNet(nn.Cell):
  223. def __init__(self, net):
  224. super(GradNet, self).__init__()
  225. self.net = net
  226. self.weights = ParameterTuple(net.trainable_params())
  227. @ms_function
  228. def construct(self, a, b, c):
  229. return grad_by_list(self.net, self.weights)(a, b, c)
  230. while_net = MyWhileNet()
  231. net = GradNet(while_net)
  232. idx = Tensor(np.array(0), dtype=ms.int32)
  233. end = Tensor(np.array(4), dtype=ms.int32)
  234. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  235. net(idx, end, x)
  236. def test_while_with_param_basic_grad():
  237. class MyWhileNet(nn.Cell):
  238. def __init__(self):
  239. super().__init__()
  240. self.max = P.ReduceMax()
  241. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  242. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  243. self.t2 = Tensor(np.array(2), dtype=ms.float32)
  244. def construct(self, idx, end, x):
  245. out = self.zero
  246. while idx < end:
  247. out = out + self.param
  248. idx = idx + 1
  249. return out + self.param
  250. class GradNet(nn.Cell):
  251. def __init__(self, net):
  252. super(GradNet, self).__init__()
  253. self.net = net
  254. self.weights = ParameterTuple(net.trainable_params())
  255. @ms_function
  256. def construct(self, a, b, c):
  257. return grad_by_list(self.net, self.weights)(a, b, c)
  258. while_net = MyWhileNet()
  259. net = GradNet(while_net)
  260. idx = Tensor(np.array(0), dtype=ms.int32)
  261. end = Tensor(np.array(3), dtype=ms.int32)
  262. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  263. net(idx, end, x)
  264. def test_while_with_param_basic_grad_mul():
  265. class MyWhileNet(nn.Cell):
  266. def __init__(self):
  267. super().__init__()
  268. self.max = P.ReduceMax()
  269. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  270. self.zero = Tensor(np.ones(([2, 2, 2])), ms.float32)
  271. self.t2 = Tensor(np.array(2), dtype=ms.float32)
  272. def construct(self, idx, end, x):
  273. out = self.zero
  274. while idx < end:
  275. out = out * self.param
  276. idx = idx + 1
  277. return out + self.param
  278. class GradNet(nn.Cell):
  279. def __init__(self, net):
  280. super(GradNet, self).__init__()
  281. self.net = net
  282. self.weights = ParameterTuple(net.trainable_params())
  283. @ms_function
  284. def construct(self, a, b, c):
  285. return grad_by_list(self.net, self.weights)(a, b, c)
  286. while_net = MyWhileNet()
  287. net = GradNet(while_net)
  288. idx = Tensor(np.array(0), dtype=ms.int32)
  289. end = Tensor(np.array(3), dtype=ms.int32)
  290. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  291. net(idx, end, x)
  292. def test_while_with_param_basic_grad_two():
  293. class MyWhileNet(nn.Cell):
  294. def __init__(self):
  295. super().__init__()
  296. self.max = P.ReduceMax()
  297. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  298. self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
  299. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  300. self.t2 = Tensor(np.array(2), dtype=ms.float32)
  301. def construct(self, idx, end, x):
  302. out = self.zero
  303. while idx < end:
  304. out = out + self.param + self.weight
  305. idx = idx + 1
  306. return out + self.param
  307. class GradNet(nn.Cell):
  308. def __init__(self, net):
  309. super(GradNet, self).__init__()
  310. self.net = net
  311. self.weights = ParameterTuple(net.trainable_params())
  312. @ms_function
  313. def construct(self, a, b, c):
  314. return grad_by_list(self.net, self.weights)(a, b, c)
  315. while_net = MyWhileNet()
  316. net = GradNet(while_net)
  317. idx = Tensor(np.array(0), dtype=ms.int32)
  318. end = Tensor(np.array(3), dtype=ms.int32)
  319. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  320. net(idx, end, x)
  321. def test_while_with_param_basic_grad_three():
  322. class MyWhileNet(nn.Cell):
  323. def __init__(self):
  324. super().__init__()
  325. self.max = P.ReduceMax()
  326. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  327. self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
  328. self.key = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="key")
  329. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  330. self.t2 = Tensor(np.array(2), dtype=ms.float32)
  331. def construct(self, idx, end, x):
  332. out = self.zero
  333. while idx < end:
  334. out = out + self.param + self.weight + self.key
  335. idx = idx + 1
  336. return out + self.param
  337. class GradNet(nn.Cell):
  338. def __init__(self, net):
  339. super(GradNet, self).__init__()
  340. self.net = net
  341. self.weights = ParameterTuple(net.trainable_params())
  342. @ms_function
  343. def construct(self, a, b, c):
  344. return grad_by_list(self.net, self.weights)(a, b, c)
  345. while_net = MyWhileNet()
  346. net = GradNet(while_net)
  347. idx = Tensor(np.array(0), dtype=ms.int32)
  348. end = Tensor(np.array(3), dtype=ms.int32)
  349. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  350. net(idx, end, x)
  351. def test_while_if_with_param_grad():
  352. class MyWhileNet(nn.Cell):
  353. def __init__(self):
  354. super().__init__()
  355. self.max = P.ReduceMax()
  356. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  357. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  358. self.t2 = Tensor(np.array(2), dtype=ms.float32)
  359. def construct(self, idx, end, x):
  360. out = self.zero
  361. while idx < end:
  362. if self.max(out) < self.max(x):
  363. out = out + self.param * 2
  364. else:
  365. out = out + self.param
  366. idx = idx + 1
  367. return out + self.param
  368. class GradNet(nn.Cell):
  369. def __init__(self, net):
  370. super(GradNet, self).__init__()
  371. self.net = net
  372. self.weights = ParameterTuple(net.trainable_params())
  373. @ms_function
  374. def construct(self, a, b, c):
  375. return grad_by_list(self.net, self.weights)(a, b, c)
  376. while_net = MyWhileNet()
  377. net = GradNet(while_net)
  378. idx = Tensor(np.array(0), dtype=ms.int32)
  379. end = Tensor(np.array(3), dtype=ms.int32)
  380. x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
  381. net(idx, end, x)
  382. def test_while_with_param_grad_not_enter_while():
  383. class MyWhileNet(nn.Cell):
  384. def __init__(self):
  385. super().__init__()
  386. self.max = P.ReduceMax()
  387. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  388. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  389. def construct(self, idx, end, x):
  390. out = self.zero
  391. while idx < end:
  392. out = out + self.param * 3
  393. idx = idx + 1
  394. return out + self.param
  395. class GradNet(nn.Cell):
  396. def __init__(self, net):
  397. super(GradNet, self).__init__()
  398. self.net = net
  399. self.weights = ParameterTuple(net.trainable_params())
  400. @ms_function
  401. def construct(self, a, b, c):
  402. return grad_by_list(self.net, self.weights)(a, b, c)
  403. while_net = MyWhileNet()
  404. net = GradNet(while_net)
  405. idx = Tensor(np.array(3), dtype=ms.int32)
  406. end = Tensor(np.array(0), dtype=ms.int32)
  407. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  408. net(idx, end, x)
  409. def test_with_param_if_by_if_forward():
  410. class MyIfByIfNet(nn.Cell):
  411. def __init__(self):
  412. super().__init__()
  413. self.max = P.ReduceMax()
  414. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  415. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  416. @ms_function
  417. def construct(self, a, b, x):
  418. out = self.zero
  419. if a < b:
  420. out = out + x + self.param
  421. else:
  422. out = out + x
  423. if a == b:
  424. out = out + x*3 + self.param
  425. else:
  426. out = out + x*2
  427. return out
  428. if_net = MyIfByIfNet()
  429. net = if_net
  430. idx = Tensor(np.array(0), dtype=ms.int32)
  431. end = Tensor(np.array(4), dtype=ms.int32)
  432. x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
  433. net(idx, end, x)
  434. def test_with_param_if_by_if_grad_inputs():
  435. class MyIfByIfNet(nn.Cell):
  436. def __init__(self):
  437. super().__init__()
  438. self.max = P.ReduceMax()
  439. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  440. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  441. def construct(self, a, b, x):
  442. out = self.zero
  443. if a < b:
  444. out = out + x + self.param * 4
  445. if a == b:
  446. out = out + x*3 + self.param * 3
  447. return out
  448. class GradNet(nn.Cell):
  449. def __init__(self, net):
  450. super(GradNet, self).__init__()
  451. self.net = net
  452. @ms_function
  453. def construct(self, *inputs):
  454. return grad_all(self.net)(*inputs)
  455. if_net = MyIfByIfNet()
  456. net = GradNet(if_net)
  457. idx = Tensor(np.array(0), dtype=ms.int32)
  458. end = Tensor(np.array(0), dtype=ms.int32)
  459. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  460. net(idx, end, x)
  461. def test_with_param_if_by_if_grad_parameter():
  462. class MyIfByIfNet(nn.Cell):
  463. def __init__(self):
  464. super().__init__()
  465. self.max = P.ReduceMax()
  466. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  467. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  468. def construct(self, a, b, x):
  469. out = self.zero
  470. if a < b:
  471. out = out + x + self.param * 2
  472. if a == b:
  473. out = out + x*3 + self.param
  474. return out
  475. class GradNet(nn.Cell):
  476. def __init__(self, net):
  477. super(GradNet, self).__init__()
  478. self.net = net
  479. self.weights = ParameterTuple(net.trainable_params())
  480. @ms_function
  481. def construct(self, *inputs):
  482. return grad_by_list(self.net, self.weights)(*inputs)
  483. if_net = MyIfByIfNet()
  484. net = GradNet(if_net)
  485. idx = Tensor(np.array(0), dtype=ms.int32)
  486. end = Tensor(np.array(2), dtype=ms.int32)
  487. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  488. net(idx, end, x)
  489. def test_with_param_if_by_if_grad_param_excute_null():
  490. class MyIfByIfNet(nn.Cell):
  491. def __init__(self):
  492. super().__init__()
  493. self.max = P.ReduceMax()
  494. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  495. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  496. def construct(self, a, b, x):
  497. out = self.zero
  498. if a < b:
  499. out = out + x + self.param * 2
  500. return out
  501. class GradNet(nn.Cell):
  502. def __init__(self, net):
  503. super(GradNet, self).__init__()
  504. self.net = net
  505. self.weights = ParameterTuple(net.trainable_params())
  506. @ms_function
  507. def construct(self, *inputs):
  508. return grad_by_list(self.net, self.weights)(*inputs)
  509. if_net = MyIfByIfNet()
  510. net = GradNet(if_net)
  511. idx = Tensor(np.array(4), dtype=ms.int32)
  512. end = Tensor(np.array(0), dtype=ms.int32)
  513. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  514. net(idx, end, x)
  515. def test_if_by_if_return_inside_grad():
  516. class MyIfByIfNet(nn.Cell):
  517. def __init__(self):
  518. super().__init__()
  519. self.max = P.ReduceMax()
  520. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  521. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  522. def construct(self, a, b, x):
  523. out = self.zero
  524. if a < b:
  525. return out + x + self.param
  526. if a == b:
  527. return out + self.param * 2
  528. return out + self.param * 3
  529. class GradNet(nn.Cell):
  530. def __init__(self, net):
  531. super(GradNet, self).__init__()
  532. self.net = net
  533. self.weights = ParameterTuple(net.trainable_params())
  534. @ms_function
  535. def construct(self, *inputs):
  536. return grad_by_list(self.net, self.weights)(*inputs)
  537. if_net = MyIfByIfNet()
  538. net = GradNet(if_net)
  539. idx = Tensor(np.array(1), dtype=ms.int32)
  540. end = Tensor(np.array(0), dtype=ms.int32)
  541. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  542. net(idx, end, x)
  543. def test_if_by_if_forward():
  544. class MyIfByIfNet(nn.Cell):
  545. def __init__(self):
  546. super().__init__()
  547. self.add = P.TensorAdd()
  548. self.sub = P.Sub()
  549. self.mul = P.Mul()
  550. self.div = P.RealDiv()
  551. @ms_function
  552. def construct(self, a, b, x):
  553. if a < b:
  554. a = self.add(a, b)
  555. else:
  556. a = self.sub(a, b)
  557. if a == x:
  558. a = self.mul(a, b)
  559. else:
  560. a = self.div(a, b)
  561. if b == x:
  562. b = self.add(a, b)
  563. else:
  564. b = self.add(a, x)
  565. a = a * b
  566. out = a + b + x
  567. return out
  568. if_net = MyIfByIfNet()
  569. net = if_net
  570. idx = Tensor(np.array(2), dtype=ms.float32)
  571. end = Tensor(np.array(3), dtype=ms.float32)
  572. x = Tensor(np.array(4), dtype=ms.float32)
  573. net(idx, end, x)
  574. def test_if_by_if_forward_control_tuple_switch():
  575. """tuple_get from swtich op will generate new switch inside to eliminate tuple_get"""
  576. class Branch3Net(nn.Cell):
  577. def __init__(self):
  578. super().__init__()
  579. self.add = P.TensorAdd()
  580. self.sub = P.Sub()
  581. self.mul = P.Mul()
  582. self.div = P.RealDiv()
  583. def construct(self, a, b, x):
  584. if b == x:
  585. b = self.add(a, b)
  586. else:
  587. b = self.add(a, x)
  588. return a, b, x
  589. class Branch2Net(nn.Cell):
  590. def __init__(self):
  591. super().__init__()
  592. self.add = P.TensorAdd()
  593. self.sub = P.Sub()
  594. self.mul = P.Mul()
  595. self.div = P.RealDiv()
  596. self.net = Branch3Net()
  597. def construct(self, a, b, x):
  598. if a == x:
  599. a = self.mul(a, b)
  600. else:
  601. a = self.div(a, b)
  602. return self.net(a, b, x)
  603. class MyIfByIfNet(nn.Cell):
  604. def __init__(self):
  605. super().__init__()
  606. self.add = P.TensorAdd()
  607. self.sub = P.Sub()
  608. self.mul = P.Mul()
  609. self.div = P.RealDiv()
  610. self.net = Branch2Net()
  611. @ms_function
  612. def construct(self, a, b, x):
  613. if a < b:
  614. a = self.add(a, b)
  615. else:
  616. a = self.sub(a, b)
  617. a, b, x = self.net(a, b, x)
  618. a = a * b
  619. out = a + b + x
  620. return out
  621. if_net = MyIfByIfNet()
  622. net = if_net
  623. idx = Tensor(np.array(2), dtype=ms.float32)
  624. end = Tensor(np.array(3), dtype=ms.float32)
  625. x = Tensor(np.array(0), dtype=ms.float32)
  626. net(idx, end, x)
  627. def test_if_by_if_forward_control_inside_net():
  628. class Branch3Net(nn.Cell):
  629. def __init__(self):
  630. super().__init__()
  631. self.add = P.TensorAdd()
  632. self.sub = P.Sub()
  633. self.mul = P.Mul()
  634. self.div = P.RealDiv()
  635. def construct(self, a, b, x):
  636. if b == x:
  637. b = self.add(a, b)
  638. else:
  639. b = self.add(a, x)
  640. a = a * b
  641. out = a + b + x
  642. return out
  643. class Branch2Net(nn.Cell):
  644. def __init__(self):
  645. super().__init__()
  646. self.add = P.TensorAdd()
  647. self.sub = P.Sub()
  648. self.mul = P.Mul()
  649. self.div = P.RealDiv()
  650. self.net = Branch3Net()
  651. def construct(self, a, b, x):
  652. if a == x:
  653. a = self.mul(a, b)
  654. else:
  655. a = self.div(a, b)
  656. return self.net(a, b, x)
  657. class MyIfByIfNet(nn.Cell):
  658. def __init__(self):
  659. super().__init__()
  660. self.add = P.TensorAdd()
  661. self.sub = P.Sub()
  662. self.mul = P.Mul()
  663. self.div = P.RealDiv()
  664. self.net = Branch2Net()
  665. @ms_function
  666. def construct(self, a, b, x):
  667. if a < b:
  668. a = self.add(a, b)
  669. else:
  670. a = self.sub(a, b)
  671. out = self.net(a, b, x)
  672. return out
  673. if_net = MyIfByIfNet()
  674. net = if_net
  675. idx = Tensor(np.array(2), dtype=ms.float32)
  676. end = Tensor(np.array(3), dtype=ms.float32)
  677. x = Tensor(np.array(0), dtype=ms.float32)
  678. net(idx, end, x)
  679. def test_if_by_if_forward_use_namespace():
  680. class MyIfByIfNet(nn.Cell):
  681. def __init__(self):
  682. super().__init__()
  683. self.add = P.TensorAdd()
  684. self.sub = P.Sub()
  685. self.mul = P.Mul()
  686. self.div = P.RealDiv()
  687. @ms_function
  688. def construct(self, a, b, x):
  689. if a < b:
  690. a = P.TensorAdd()(a, b)
  691. else:
  692. a = P.Sub()(a, b)
  693. if a == x:
  694. a = P.Mul()(a, b)
  695. else:
  696. a = P.RealDiv()(a, b)
  697. if b == x:
  698. b = P.TensorAdd()(a, b)
  699. else:
  700. b = P.TensorAdd()(a, x)
  701. a = a * b
  702. out = a + b + x
  703. return out
  704. if_net = MyIfByIfNet()
  705. net = if_net
  706. idx = Tensor(np.array(2), dtype=ms.float32)
  707. end = Tensor(np.array(3), dtype=ms.float32)
  708. x = Tensor(np.array(0), dtype=ms.float32)
  709. net(idx, end, x)
  710. def test_if_by_if_forward_use_global_op():
  711. class MyIfByIfNet(nn.Cell):
  712. def __init__(self):
  713. super().__init__()
  714. self.add = P.TensorAdd()
  715. self.sub = P.Sub()
  716. self.mul = P.Mul()
  717. self.div = P.RealDiv()
  718. @ms_function
  719. def construct(self, a, b, x):
  720. add = P.TensorAdd()
  721. sub = P.Sub()
  722. mul = P.Mul()
  723. div = P.RealDiv()
  724. if a < b:
  725. a = add(a, b)
  726. else:
  727. a = sub(a, b)
  728. if a == x:
  729. a = mul(a, b)
  730. else:
  731. a = div(a, b)
  732. if b == x:
  733. b = add(a, b)
  734. else:
  735. b = add(a, x)
  736. a = a * b
  737. out = a + b + x
  738. return out
  739. if_net = MyIfByIfNet()
  740. net = if_net
  741. idx = Tensor(np.array(2), dtype=ms.float32)
  742. end = Tensor(np.array(3), dtype=ms.float32)
  743. x = Tensor(np.array(0), dtype=ms.float32)
  744. net(idx, end, x)
  745. def test_for_with_if_by_if_forward():
  746. class MyIfByIfNet(nn.Cell):
  747. def __init__(self):
  748. super().__init__()
  749. self.add = P.TensorAdd()
  750. self.sub = P.Sub()
  751. @ms_function
  752. def construct(self, a, b, x):
  753. for _ in range(0, 4):
  754. if a < b:
  755. a = self.add(a, b)
  756. else:
  757. b = self.sub(b, x)
  758. a = a * b
  759. out = a + b + x
  760. return out
  761. if_net = MyIfByIfNet()
  762. net = if_net
  763. idx = Tensor(np.array(2), dtype=ms.float32)
  764. end = Tensor(np.array(3), dtype=ms.float32)
  765. x = Tensor(np.array(0), dtype=ms.float32)
  766. net(idx, end, x)
  767. def test_for_with_if_by_if_forward_namespace():
  768. class MyIfByIfNet(nn.Cell):
  769. def __init__(self):
  770. super().__init__()
  771. self.add = P.TensorAdd()
  772. self.sub = P.Sub()
  773. self.mul = P.Mul()
  774. self.div = P.RealDiv()
  775. @ms_function
  776. def construct(self, a, b, x):
  777. for _ in range(0, 6):
  778. if a < b:
  779. a = P.TensorAdd()(a, b)
  780. else:
  781. b = P.Sub()(b, x)
  782. a = a * b
  783. out = a + b + x
  784. return out
  785. if_net = MyIfByIfNet()
  786. net = if_net
  787. idx = Tensor(np.array(2), dtype=ms.float32)
  788. end = Tensor(np.array(3), dtype=ms.float32)
  789. x = Tensor(np.array(0), dtype=ms.float32)
  790. net(idx, end, x)
  791. def test_if_by_if_forward_const_branch_inner():
  792. class MyIfByIfNet(nn.Cell):
  793. def __init__(self):
  794. super().__init__()
  795. self.add = P.TensorAdd()
  796. self.sub = P.Sub()
  797. self.mul = P.Mul()
  798. self.div = P.RealDiv()
  799. @ms_function
  800. def construct(self, a, b, x):
  801. add = P.TensorAdd()
  802. sub = P.Sub()
  803. mul = P.Mul()
  804. div = P.RealDiv()
  805. if a < b:
  806. a = add(a, b)
  807. else:
  808. a = sub(a, b)
  809. if 2 > 1:
  810. a = mul(a, b)
  811. else:
  812. a = div(a, b)
  813. if b == x:
  814. b = add(a, b)
  815. else:
  816. b = add(a, x)
  817. a = a * b
  818. out = a + b + x
  819. return out
  820. if_net = MyIfByIfNet()
  821. net = if_net
  822. idx = Tensor(np.array(2), dtype=ms.float32)
  823. end = Tensor(np.array(3), dtype=ms.float32)
  824. x = Tensor(np.array(0), dtype=ms.float32)
  825. net(idx, end, x)
  826. def test_if_by_if_forward_all_const_branch():
  827. class MyIfByIfNet(nn.Cell):
  828. def __init__(self):
  829. super().__init__()
  830. self.add = P.TensorAdd()
  831. self.sub = P.Sub()
  832. self.mul = P.Mul()
  833. self.div = P.RealDiv()
  834. @ms_function
  835. def construct(self, a, b, x):
  836. add = P.TensorAdd()
  837. sub = P.Sub()
  838. mul = P.Mul()
  839. div = P.RealDiv()
  840. if 2 < 12:
  841. a = add(a, b)
  842. else:
  843. a = sub(a, b)
  844. if 2 > 1:
  845. a = mul(a, b)
  846. else:
  847. a = div(a, b)
  848. if 2 == 1:
  849. b = add(a, b)
  850. else:
  851. b = add(a, x)
  852. a = a * b
  853. out = a + b + x
  854. return out
  855. if_net = MyIfByIfNet()
  856. net = if_net
  857. idx = Tensor(np.array(2), dtype=ms.float32)
  858. end = Tensor(np.array(3), dtype=ms.float32)
  859. x = Tensor(np.array(0), dtype=ms.float32)
  860. net(idx, end, x)