You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_cont_grad.py 38 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test control ops """
  16. import numpy as np
  17. from mindspore import dtype as ms
  18. from mindspore import Tensor
  19. from mindspore import context
  20. from mindspore import nn
  21. from mindspore.common.parameter import Parameter, ParameterTuple
  22. from mindspore.ops import composite as C
  23. from mindspore.ops import operations as P
  24. # from tests.vm_impl.math_ops_vm_impl import *
  25. # from tests.vm_impl.vm_interface import *
  26. # from tests.vm_impl import *
  27. # context.set_context(save_graphs=True)
# Shared gradient operators used by the GradNet wrappers in the tests of this file.
# grad_by_list(net, weights) differentiates with respect to the ParameterTuple
# passed as the second argument (the tests pass net.trainable_params()).
grad_by_list = C.GradOperation(get_by_list=True)
# grad_all(net) differentiates with respect to every positional input of net.
grad_all = C.GradOperation(get_all=True)
  30. def test_while_forward():
  31. class MyWhileNet(nn.Cell):
  32. def __init__(self):
  33. super().__init__()
  34. self.max = P.ReduceMax()
  35. def construct(self, idx, end, x):
  36. while idx < end:
  37. part = x[idx, :, :]
  38. max_num = self.max(part)
  39. x[idx, :, 0:2] = max_num
  40. idx = idx + 1
  41. return x
  42. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  43. net = MyWhileNet()
  44. idx = Tensor(np.array(0), dtype=ms.int32)
  45. end = Tensor(np.array(2), dtype=ms.int32)
  46. x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
  47. net(idx, end, x)
  48. def test_while_grad():
  49. class MyWhileNet(nn.Cell):
  50. def __init__(self):
  51. super().__init__()
  52. self.max = P.ReduceMax()
  53. def construct(self, idx, end, x):
  54. while idx < end:
  55. part = x[idx, :, :]
  56. max_num = self.max(part)
  57. x[idx, :, 0:2] = max_num
  58. idx = idx + 1
  59. return x
  60. class GradNet(nn.Cell):
  61. def __init__(self, net):
  62. super(GradNet, self).__init__()
  63. self.net = net
  64. def construct(self, *inputs):
  65. return grad_all(self.net)(*inputs)
  66. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  67. while_net = MyWhileNet()
  68. net = GradNet(while_net)
  69. idx = Tensor(np.array(0), dtype=ms.int32)
  70. end = Tensor(np.array(2), dtype=ms.int32)
  71. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  72. net(idx, end, x)
  73. def test_while_with_param_forward():
  74. class MyWhileNet(nn.Cell):
  75. def __init__(self):
  76. super().__init__()
  77. self.max = P.ReduceMax()
  78. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  79. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  80. def construct(self, idx, end, x):
  81. out = self.zero
  82. while idx < end:
  83. part = x[idx, :, :]
  84. max_num = self.max(part)
  85. x[idx, :, 0:2] = max_num
  86. out = out + x + self.param
  87. idx = idx + 1
  88. return out
  89. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  90. net = MyWhileNet()
  91. idx = Tensor(np.array(0), dtype=ms.int32)
  92. end = Tensor(np.array(2), dtype=ms.int32)
  93. x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
  94. net(idx, end, x)
  95. def test_while_endless_case():
  96. """endless case when optmization"""
  97. class MyWhileNet(nn.Cell):
  98. def __init__(self):
  99. super().__init__()
  100. self.max = P.ReduceMax()
  101. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  102. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  103. def construct(self, idx, end, x):
  104. out = self.zero
  105. while idx < end:
  106. part = x[idx, :, :]
  107. out = out + part
  108. idx = idx + 1
  109. return out
  110. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  111. net = MyWhileNet()
  112. idx = Tensor(np.array(0), dtype=ms.int32)
  113. end = Tensor(np.array(2), dtype=ms.int32)
  114. x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
  115. net(idx, end, x)
  116. def test_while_with_param_grad():
  117. class MyWhileNet(nn.Cell):
  118. def __init__(self):
  119. super().__init__()
  120. self.max = P.ReduceMax()
  121. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  122. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  123. def construct(self, idx, end, x):
  124. out = self.zero
  125. while idx < end:
  126. part = x[idx, :, :]
  127. max_num = self.max(part)
  128. x[idx, :, 0:2] = max_num
  129. out = out + x + self.param
  130. idx = idx + 1
  131. return out
  132. class GradNet(nn.Cell):
  133. def __init__(self, net):
  134. super(GradNet, self).__init__()
  135. self.net = net
  136. self.weights = ParameterTuple(net.trainable_params())
  137. def construct(self, a, b, c):
  138. return grad_by_list(self.net, self.weights)(a, b, c)
  139. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  140. while_net = MyWhileNet()
  141. net = GradNet(while_net)
  142. idx = Tensor(np.array(0), dtype=ms.int32)
  143. end = Tensor(np.array(2), dtype=ms.int32)
  144. x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
  145. net(idx, end, x)
  146. def test_while_with_param_forward_with_const_branch():
  147. class MyWhileNet(nn.Cell):
  148. def __init__(self):
  149. super().__init__()
  150. self.max = P.ReduceMax()
  151. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  152. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  153. self.reduce = P.ReduceSum()
  154. def construct(self, idx, end, x):
  155. out = self.zero
  156. while idx < end:
  157. if 2 > 1:
  158. out = out + self.param
  159. else:
  160. out = out + idx + self.param
  161. idx = idx + 1
  162. return out
  163. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  164. while_net = MyWhileNet()
  165. net = while_net
  166. idx = Tensor(np.array(0), dtype=ms.int32)
  167. end = Tensor(np.array(4), dtype=ms.int32)
  168. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  169. net(idx, end, x)
  170. def test_while_opt_endless():
  171. """endless during optimization case"""
  172. class MyWhileNet(nn.Cell):
  173. def __init__(self):
  174. super().__init__()
  175. self.max = P.ReduceMax()
  176. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  177. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  178. self.reduce = P.ReduceSum()
  179. self.addn = P.AddN()
  180. def construct(self, idx, end, x):
  181. addn1 = self.addn((x, x, x))
  182. out = addn1
  183. while idx < end:
  184. out = self.addn((out, addn1))
  185. idx = idx + 1
  186. out = self.addn((out, x))
  187. return out
  188. class GradNet(nn.Cell):
  189. def __init__(self, net):
  190. super(GradNet, self).__init__()
  191. self.net = net
  192. def construct(self, *inputs):
  193. return grad_all(self.net)(*inputs)
  194. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  195. while_net = MyWhileNet()
  196. net = GradNet(while_net)
  197. idx = Tensor(np.array(0), dtype=ms.int32)
  198. end = Tensor(np.array(4), dtype=ms.int32)
  199. x = Tensor(np.ones([2, 2, 2]).astype(np.float32) * 3, dtype=ms.float32)
  200. net(idx, end, x)
  201. def test_no_while_call():
  202. class MyWhileNet(nn.Cell):
  203. def __init__(self):
  204. super().__init__()
  205. self.max = P.ReduceMax()
  206. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  207. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  208. self.reduce = P.ReduceSum()
  209. def construct(self, idx, end, x):
  210. out = self.zero
  211. if 2 > 1:
  212. out = out + self.param
  213. else:
  214. out = out + idx + self.param
  215. return out
  216. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  217. while_net = MyWhileNet()
  218. net = while_net
  219. idx = Tensor(np.array(0), dtype=ms.int32)
  220. end = Tensor(np.array(4), dtype=ms.int32)
  221. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  222. net(idx, end, x)
  223. def test_while_with_param_grad_with_const_branch():
  224. class MyWhileNet(nn.Cell):
  225. def __init__(self):
  226. super().__init__()
  227. self.max = P.ReduceMax()
  228. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  229. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  230. self.reduce = P.ReduceSum()
  231. def construct(self, idx, end, x):
  232. out = self.zero
  233. while idx < end:
  234. if 2 > 1:
  235. out = out + self.param
  236. else:
  237. out = out + idx + self.param
  238. idx = idx + 1
  239. return out
  240. class GradNet(nn.Cell):
  241. def __init__(self, net):
  242. super(GradNet, self).__init__()
  243. self.net = net
  244. self.weights = ParameterTuple(net.trainable_params())
  245. def construct(self, a, b, c):
  246. return grad_by_list(self.net, self.weights)(a, b, c)
  247. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  248. while_net = MyWhileNet()
  249. net = GradNet(while_net)
  250. idx = Tensor(np.array(0), dtype=ms.int32)
  251. end = Tensor(np.array(4), dtype=ms.int32)
  252. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  253. net(idx, end, x)
  254. def test_for_while_with_param_grad_with_const_branch():
  255. class MyWhileNet(nn.Cell):
  256. def __init__(self):
  257. super().__init__()
  258. self.max = P.ReduceMax()
  259. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  260. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  261. self.reduce = P.ReduceSum()
  262. self.start = Tensor(np.array(0), dtype=ms.int32)
  263. def construct(self, idx, end, x):
  264. out = self.zero
  265. for _ in range(0, 2):
  266. idx = self.start
  267. while idx < end:
  268. if 2 > 1:
  269. out = out + self.param
  270. else:
  271. out = out + idx + self.param
  272. idx = idx + 1
  273. return out
  274. class GradNet(nn.Cell):
  275. def __init__(self, net):
  276. super(GradNet, self).__init__()
  277. self.net = net
  278. self.weights = ParameterTuple(net.trainable_params())
  279. def construct(self, a, b, c):
  280. return grad_by_list(self.net, self.weights)(a, b, c)
  281. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  282. while_net = MyWhileNet()
  283. net = GradNet(while_net)
  284. idx = Tensor(np.array(0), dtype=ms.int32)
  285. end = Tensor(np.array(4), dtype=ms.int32)
  286. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  287. net(idx, end, x)
  288. def test_for_while_with_param_grad_basic():
  289. class MyWhileNet(nn.Cell):
  290. def __init__(self):
  291. super().__init__()
  292. self.max = P.ReduceMax()
  293. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  294. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  295. self.reduce = P.ReduceSum()
  296. self.start = Tensor(np.array(0), dtype=ms.int32)
  297. def construct(self, idx, end, x):
  298. out = self.zero
  299. for _ in range(0, 2):
  300. idx = self.start
  301. while idx < end:
  302. out = out + self.param
  303. idx = idx + 1
  304. return out
  305. class GradNet(nn.Cell):
  306. def __init__(self, net):
  307. super(GradNet, self).__init__()
  308. self.net = net
  309. self.weights = ParameterTuple(net.trainable_params())
  310. def construct(self, a, b, c):
  311. return grad_by_list(self.net, self.weights)(a, b, c)
  312. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  313. while_net = MyWhileNet()
  314. net = GradNet(while_net)
  315. idx = Tensor(np.array(0), dtype=ms.int32)
  316. end = Tensor(np.array(4), dtype=ms.int32)
  317. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  318. net(idx, end, x)
  319. def test_for_while_with_param_grad_normal():
  320. class MyWhileNet(nn.Cell):
  321. def __init__(self):
  322. super().__init__()
  323. self.max = P.ReduceMax()
  324. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  325. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  326. self.reduce = P.ReduceSum()
  327. self.start = Tensor(np.array(0), dtype=ms.int32)
  328. def construct(self, idx, end, x):
  329. out = x
  330. for _ in range(0, 2):
  331. idx = self.start
  332. while idx < end:
  333. out = out + self.param
  334. idx = idx + 1
  335. return out
  336. class GradNet(nn.Cell):
  337. def __init__(self, net):
  338. super(GradNet, self).__init__()
  339. self.net = net
  340. self.weights = ParameterTuple(net.trainable_params())
  341. def construct(self, a, b, c):
  342. return grad_by_list(self.net, self.weights)(a, b, c)
  343. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  344. while_net = MyWhileNet()
  345. net = GradNet(while_net)
  346. idx = Tensor(np.array(0), dtype=ms.int32)
  347. end = Tensor(np.array(4), dtype=ms.int32)
  348. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  349. net(idx, end, x)
  350. def test_while_with_param_basic_grad():
  351. class MyWhileNet(nn.Cell):
  352. def __init__(self):
  353. super().__init__()
  354. self.max = P.ReduceMax()
  355. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  356. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  357. self.t2 = Tensor(np.array(2), dtype=ms.float32)
  358. def construct(self, idx, end, x):
  359. out = self.zero
  360. while idx < end:
  361. out = out + self.param
  362. idx = idx + 1
  363. return out + self.param
  364. class GradNet(nn.Cell):
  365. def __init__(self, net):
  366. super(GradNet, self).__init__()
  367. self.net = net
  368. self.weights = ParameterTuple(net.trainable_params())
  369. def construct(self, a, b, c):
  370. return grad_by_list(self.net, self.weights)(a, b, c)
  371. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  372. while_net = MyWhileNet()
  373. net = GradNet(while_net)
  374. idx = Tensor(np.array(0), dtype=ms.int32)
  375. end = Tensor(np.array(3), dtype=ms.int32)
  376. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  377. net(idx, end, x)
  378. def test_while_with_param_basic_grad_mul():
  379. class MyWhileNet(nn.Cell):
  380. def __init__(self):
  381. super().__init__()
  382. self.max = P.ReduceMax()
  383. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  384. self.zero = Tensor(np.ones(([2, 2, 2])), ms.float32)
  385. self.t2 = Tensor(np.array(2), dtype=ms.float32)
  386. def construct(self, idx, end, x):
  387. out = self.zero
  388. while idx < end:
  389. out = out * self.param
  390. idx = idx + 1
  391. return out + self.param
  392. class GradNet(nn.Cell):
  393. def __init__(self, net):
  394. super(GradNet, self).__init__()
  395. self.net = net
  396. self.weights = ParameterTuple(net.trainable_params())
  397. def construct(self, a, b, c):
  398. return grad_by_list(self.net, self.weights)(a, b, c)
  399. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  400. while_net = MyWhileNet()
  401. net = GradNet(while_net)
  402. idx = Tensor(np.array(0), dtype=ms.int32)
  403. end = Tensor(np.array(3), dtype=ms.int32)
  404. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  405. net(idx, end, x)
  406. def test_while_with_param_basic_grad_two():
  407. class MyWhileNet(nn.Cell):
  408. def __init__(self):
  409. super().__init__()
  410. self.max = P.ReduceMax()
  411. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  412. self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
  413. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  414. self.t2 = Tensor(np.array(2), dtype=ms.float32)
  415. def construct(self, idx, end, x):
  416. out = self.zero
  417. while idx < end:
  418. out = out + self.param + self.weight
  419. idx = idx + 1
  420. return out + self.param
  421. class GradNet(nn.Cell):
  422. def __init__(self, net):
  423. super(GradNet, self).__init__()
  424. self.net = net
  425. self.weights = ParameterTuple(net.trainable_params())
  426. def construct(self, a, b, c):
  427. return grad_by_list(self.net, self.weights)(a, b, c)
  428. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  429. while_net = MyWhileNet()
  430. net = GradNet(while_net)
  431. idx = Tensor(np.array(0), dtype=ms.int32)
  432. end = Tensor(np.array(3), dtype=ms.int32)
  433. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  434. net(idx, end, x)
  435. def test_while_with_param_basic_grad_three():
  436. class MyWhileNet(nn.Cell):
  437. def __init__(self):
  438. super().__init__()
  439. self.max = P.ReduceMax()
  440. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  441. self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
  442. self.key = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="key")
  443. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  444. self.t2 = Tensor(np.array(2), dtype=ms.float32)
  445. def construct(self, idx, end, x):
  446. out = self.zero
  447. while idx < end:
  448. out = out + self.param + self.weight + self.key
  449. idx = idx + 1
  450. return out + self.param
  451. class GradNet(nn.Cell):
  452. def __init__(self, net):
  453. super(GradNet, self).__init__()
  454. self.net = net
  455. self.weights = ParameterTuple(net.trainable_params())
  456. def construct(self, a, b, c):
  457. return grad_by_list(self.net, self.weights)(a, b, c)
  458. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  459. while_net = MyWhileNet()
  460. net = GradNet(while_net)
  461. idx = Tensor(np.array(0), dtype=ms.int32)
  462. end = Tensor(np.array(3), dtype=ms.int32)
  463. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  464. net(idx, end, x)
  465. def test_while_if_with_param_grad():
  466. class MyWhileNet(nn.Cell):
  467. def __init__(self):
  468. super().__init__()
  469. self.max = P.ReduceMax()
  470. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  471. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  472. self.t2 = Tensor(np.array(2), dtype=ms.float32)
  473. def construct(self, idx, end, x):
  474. out = self.zero
  475. while idx < end:
  476. if self.max(out) < self.max(x):
  477. out = out + self.param * 2
  478. else:
  479. out = out + self.param
  480. idx = idx + 1
  481. return out + self.param
  482. class GradNet(nn.Cell):
  483. def __init__(self, net):
  484. super(GradNet, self).__init__()
  485. self.net = net
  486. self.weights = ParameterTuple(net.trainable_params())
  487. def construct(self, a, b, c):
  488. return grad_by_list(self.net, self.weights)(a, b, c)
  489. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  490. while_net = MyWhileNet()
  491. net = GradNet(while_net)
  492. idx = Tensor(np.array(0), dtype=ms.int32)
  493. end = Tensor(np.array(3), dtype=ms.int32)
  494. x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
  495. net(idx, end, x)
  496. def test_while_with_param_grad_not_enter_while():
  497. class MyWhileNet(nn.Cell):
  498. def __init__(self):
  499. super().__init__()
  500. self.max = P.ReduceMax()
  501. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  502. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  503. def construct(self, idx, end, x):
  504. out = self.zero
  505. while idx < end:
  506. out = out + self.param * 3
  507. idx = idx + 1
  508. return out + self.param
  509. class GradNet(nn.Cell):
  510. def __init__(self, net):
  511. super(GradNet, self).__init__()
  512. self.net = net
  513. self.weights = ParameterTuple(net.trainable_params())
  514. def construct(self, a, b, c):
  515. return grad_by_list(self.net, self.weights)(a, b, c)
  516. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  517. while_net = MyWhileNet()
  518. net = GradNet(while_net)
  519. idx = Tensor(np.array(3), dtype=ms.int32)
  520. end = Tensor(np.array(0), dtype=ms.int32)
  521. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  522. net(idx, end, x)
  523. def test_with_param_if_by_if_forward():
  524. class MyIfByIfNet(nn.Cell):
  525. def __init__(self):
  526. super().__init__()
  527. self.max = P.ReduceMax()
  528. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  529. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  530. def construct(self, a, b, x):
  531. out = self.zero
  532. if a < b:
  533. out = out + x + self.param
  534. else:
  535. out = out + x
  536. if a == b:
  537. out = out + x*3 + self.param
  538. else:
  539. out = out + x*2
  540. return out
  541. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  542. if_net = MyIfByIfNet()
  543. net = if_net
  544. idx = Tensor(np.array(0), dtype=ms.int32)
  545. end = Tensor(np.array(4), dtype=ms.int32)
  546. x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
  547. net(idx, end, x)
  548. def test_with_param_if_by_if_grad_inputs():
  549. class MyIfByIfNet(nn.Cell):
  550. def __init__(self):
  551. super().__init__()
  552. self.max = P.ReduceMax()
  553. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  554. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  555. def construct(self, a, b, x):
  556. out = self.zero
  557. if a < b:
  558. out = out + x + self.param * 4
  559. if a == b:
  560. out = out + x*3 + self.param * 3
  561. return out
  562. class GradNet(nn.Cell):
  563. def __init__(self, net):
  564. super(GradNet, self).__init__()
  565. self.net = net
  566. def construct(self, *inputs):
  567. return grad_all(self.net)(*inputs)
  568. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  569. if_net = MyIfByIfNet()
  570. net = GradNet(if_net)
  571. idx = Tensor(np.array(0), dtype=ms.int32)
  572. end = Tensor(np.array(0), dtype=ms.int32)
  573. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  574. net(idx, end, x)
  575. def test_with_param_if_by_if_grad_parameter():
  576. class MyIfByIfNet(nn.Cell):
  577. def __init__(self):
  578. super().__init__()
  579. self.max = P.ReduceMax()
  580. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  581. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  582. def construct(self, a, b, x):
  583. out = self.zero
  584. if a < b:
  585. out = out + x + self.param * 2
  586. if a == b:
  587. out = out + x*3 + self.param
  588. return out
  589. class GradNet(nn.Cell):
  590. def __init__(self, net):
  591. super(GradNet, self).__init__()
  592. self.net = net
  593. self.weights = ParameterTuple(net.trainable_params())
  594. def construct(self, *inputs):
  595. return grad_by_list(self.net, self.weights)(*inputs)
  596. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  597. if_net = MyIfByIfNet()
  598. net = GradNet(if_net)
  599. idx = Tensor(np.array(0), dtype=ms.int32)
  600. end = Tensor(np.array(2), dtype=ms.int32)
  601. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  602. net(idx, end, x)
  603. def test_with_param_if_by_if_grad_param_excute_null():
  604. class MyIfByIfNet(nn.Cell):
  605. def __init__(self):
  606. super().__init__()
  607. self.max = P.ReduceMax()
  608. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  609. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  610. def construct(self, a, b, x):
  611. out = self.zero
  612. if a < b:
  613. out = out + x + self.param * 2
  614. return out
  615. class GradNet(nn.Cell):
  616. def __init__(self, net):
  617. super(GradNet, self).__init__()
  618. self.net = net
  619. self.weights = ParameterTuple(net.trainable_params())
  620. def construct(self, *inputs):
  621. return grad_by_list(self.net, self.weights)(*inputs)
  622. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  623. if_net = MyIfByIfNet()
  624. net = GradNet(if_net)
  625. idx = Tensor(np.array(4), dtype=ms.int32)
  626. end = Tensor(np.array(0), dtype=ms.int32)
  627. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  628. net(idx, end, x)
  629. def test_if_by_if_return_inside_grad():
  630. class MyIfByIfNet(nn.Cell):
  631. def __init__(self):
  632. super().__init__()
  633. self.max = P.ReduceMax()
  634. self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
  635. self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
  636. def construct(self, a, b, x):
  637. out = self.zero
  638. if a < b:
  639. return out + x + self.param
  640. if a == b:
  641. return out + self.param * 2
  642. return out + self.param * 3
  643. class GradNet(nn.Cell):
  644. def __init__(self, net):
  645. super(GradNet, self).__init__()
  646. self.net = net
  647. self.weights = ParameterTuple(net.trainable_params())
  648. def construct(self, *inputs):
  649. return grad_by_list(self.net, self.weights)(*inputs)
  650. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  651. if_net = MyIfByIfNet()
  652. net = GradNet(if_net)
  653. idx = Tensor(np.array(1), dtype=ms.int32)
  654. end = Tensor(np.array(0), dtype=ms.int32)
  655. x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
  656. net(idx, end, x)
  657. def test_if_by_if_forward():
  658. class MyIfByIfNet(nn.Cell):
  659. def __init__(self):
  660. super().__init__()
  661. self.add = P.TensorAdd()
  662. self.sub = P.Sub()
  663. self.mul = P.Mul()
  664. self.div = P.RealDiv()
  665. def construct(self, a, b, x):
  666. if a < b:
  667. a = self.add(a, b)
  668. else:
  669. a = self.sub(a, b)
  670. if a == x:
  671. a = self.mul(a, b)
  672. else:
  673. a = self.div(a, b)
  674. if b == x:
  675. b = self.add(a, b)
  676. else:
  677. b = self.add(a, x)
  678. a = a * b
  679. out = a + b + x
  680. return out
  681. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  682. if_net = MyIfByIfNet()
  683. net = if_net
  684. idx = Tensor(np.array(2), dtype=ms.float32)
  685. end = Tensor(np.array(3), dtype=ms.float32)
  686. x = Tensor(np.array(4), dtype=ms.float32)
  687. net(idx, end, x)
  688. def test_if_by_if_forward_control_tuple_switch():
  689. """tuple_get from swtich op will generate new switch inside to eliminate tuple_get"""
  690. class Branch3Net(nn.Cell):
  691. def __init__(self):
  692. super().__init__()
  693. self.add = P.TensorAdd()
  694. self.sub = P.Sub()
  695. self.mul = P.Mul()
  696. self.div = P.RealDiv()
  697. def construct(self, a, b, x):
  698. if b == x:
  699. b = self.add(a, b)
  700. else:
  701. b = self.add(a, x)
  702. return a, b, x
  703. class Branch2Net(nn.Cell):
  704. def __init__(self):
  705. super().__init__()
  706. self.add = P.TensorAdd()
  707. self.sub = P.Sub()
  708. self.mul = P.Mul()
  709. self.div = P.RealDiv()
  710. self.net = Branch3Net()
  711. def construct(self, a, b, x):
  712. if a == x:
  713. a = self.mul(a, b)
  714. else:
  715. a = self.div(a, b)
  716. return self.net(a, b, x)
  717. class MyIfByIfNet(nn.Cell):
  718. def __init__(self):
  719. super().__init__()
  720. self.add = P.TensorAdd()
  721. self.sub = P.Sub()
  722. self.mul = P.Mul()
  723. self.div = P.RealDiv()
  724. self.net = Branch2Net()
  725. def construct(self, a, b, x):
  726. if a < b:
  727. a = self.add(a, b)
  728. else:
  729. a = self.sub(a, b)
  730. a, b, x = self.net(a, b, x)
  731. a = a * b
  732. out = a + b + x
  733. return out
  734. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  735. if_net = MyIfByIfNet()
  736. net = if_net
  737. idx = Tensor(np.array(2), dtype=ms.float32)
  738. end = Tensor(np.array(3), dtype=ms.float32)
  739. x = Tensor(np.array(0), dtype=ms.float32)
  740. net(idx, end, x)
  741. def test_if_by_if_forward_control_inside_net():
  742. class Branch3Net(nn.Cell):
  743. def __init__(self):
  744. super().__init__()
  745. self.add = P.TensorAdd()
  746. self.sub = P.Sub()
  747. self.mul = P.Mul()
  748. self.div = P.RealDiv()
  749. def construct(self, a, b, x):
  750. if b == x:
  751. b = self.add(a, b)
  752. else:
  753. b = self.add(a, x)
  754. a = a * b
  755. out = a + b + x
  756. return out
  757. class Branch2Net(nn.Cell):
  758. def __init__(self):
  759. super().__init__()
  760. self.add = P.TensorAdd()
  761. self.sub = P.Sub()
  762. self.mul = P.Mul()
  763. self.div = P.RealDiv()
  764. self.net = Branch3Net()
  765. def construct(self, a, b, x):
  766. if a == x:
  767. a = self.mul(a, b)
  768. else:
  769. a = self.div(a, b)
  770. return self.net(a, b, x)
  771. class MyIfByIfNet(nn.Cell):
  772. def __init__(self):
  773. super().__init__()
  774. self.add = P.TensorAdd()
  775. self.sub = P.Sub()
  776. self.mul = P.Mul()
  777. self.div = P.RealDiv()
  778. self.net = Branch2Net()
  779. def construct(self, a, b, x):
  780. if a < b:
  781. a = self.add(a, b)
  782. else:
  783. a = self.sub(a, b)
  784. out = self.net(a, b, x)
  785. return out
  786. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  787. if_net = MyIfByIfNet()
  788. net = if_net
  789. idx = Tensor(np.array(2), dtype=ms.float32)
  790. end = Tensor(np.array(3), dtype=ms.float32)
  791. x = Tensor(np.array(0), dtype=ms.float32)
  792. net(idx, end, x)
  793. def test_if_by_if_forward_use_namespace():
  794. class MyIfByIfNet(nn.Cell):
  795. def __init__(self):
  796. super().__init__()
  797. self.add = P.TensorAdd()
  798. self.sub = P.Sub()
  799. self.mul = P.Mul()
  800. self.div = P.RealDiv()
  801. def construct(self, a, b, x):
  802. if a < b:
  803. a = P.TensorAdd()(a, b)
  804. else:
  805. a = P.Sub()(a, b)
  806. if a == x:
  807. a = P.Mul()(a, b)
  808. else:
  809. a = P.RealDiv()(a, b)
  810. if b == x:
  811. b = P.TensorAdd()(a, b)
  812. else:
  813. b = P.TensorAdd()(a, x)
  814. a = a * b
  815. out = a + b + x
  816. return out
  817. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  818. if_net = MyIfByIfNet()
  819. net = if_net
  820. idx = Tensor(np.array(2), dtype=ms.float32)
  821. end = Tensor(np.array(3), dtype=ms.float32)
  822. x = Tensor(np.array(0), dtype=ms.float32)
  823. net(idx, end, x)
  824. def test_if_by_if_forward_use_global_op():
  825. class MyIfByIfNet(nn.Cell):
  826. def __init__(self):
  827. super().__init__()
  828. self.add = P.TensorAdd()
  829. self.sub = P.Sub()
  830. self.mul = P.Mul()
  831. self.div = P.RealDiv()
  832. def construct(self, a, b, x):
  833. add = P.TensorAdd()
  834. sub = P.Sub()
  835. mul = P.Mul()
  836. div = P.RealDiv()
  837. if a < b:
  838. a = add(a, b)
  839. else:
  840. a = sub(a, b)
  841. if a == x:
  842. a = mul(a, b)
  843. else:
  844. a = div(a, b)
  845. if b == x:
  846. b = add(a, b)
  847. else:
  848. b = add(a, x)
  849. a = a * b
  850. out = a + b + x
  851. return out
  852. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  853. if_net = MyIfByIfNet()
  854. net = if_net
  855. idx = Tensor(np.array(2), dtype=ms.float32)
  856. end = Tensor(np.array(3), dtype=ms.float32)
  857. x = Tensor(np.array(0), dtype=ms.float32)
  858. net(idx, end, x)
  859. def test_for_with_if_by_if_forward():
  860. class MyIfByIfNet(nn.Cell):
  861. def __init__(self):
  862. super().__init__()
  863. self.add = P.TensorAdd()
  864. self.sub = P.Sub()
  865. def construct(self, a, b, x):
  866. for _ in range(0, 4):
  867. if a < b:
  868. a = self.add(a, b)
  869. else:
  870. b = self.sub(b, x)
  871. a = a * b
  872. out = a + b + x
  873. return out
  874. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  875. if_net = MyIfByIfNet()
  876. net = if_net
  877. idx = Tensor(np.array(2), dtype=ms.float32)
  878. end = Tensor(np.array(3), dtype=ms.float32)
  879. x = Tensor(np.array(0), dtype=ms.float32)
  880. net(idx, end, x)
  881. def test_for_with_if_by_if_forward_namespace():
  882. class MyIfByIfNet(nn.Cell):
  883. def __init__(self):
  884. super().__init__()
  885. self.add = P.TensorAdd()
  886. self.sub = P.Sub()
  887. self.mul = P.Mul()
  888. self.div = P.RealDiv()
  889. def construct(self, a, b, x):
  890. for _ in range(0, 6):
  891. if a < b:
  892. a = P.TensorAdd()(a, b)
  893. else:
  894. b = P.Sub()(b, x)
  895. a = a * b
  896. out = a + b + x
  897. return out
  898. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  899. if_net = MyIfByIfNet()
  900. net = if_net
  901. idx = Tensor(np.array(2), dtype=ms.float32)
  902. end = Tensor(np.array(3), dtype=ms.float32)
  903. x = Tensor(np.array(0), dtype=ms.float32)
  904. net(idx, end, x)
  905. def test_if_by_if_forward_const_branch_inner():
  906. class MyIfByIfNet(nn.Cell):
  907. def __init__(self):
  908. super().__init__()
  909. self.add = P.TensorAdd()
  910. self.sub = P.Sub()
  911. self.mul = P.Mul()
  912. self.div = P.RealDiv()
  913. def construct(self, a, b, x):
  914. add = P.TensorAdd()
  915. sub = P.Sub()
  916. mul = P.Mul()
  917. div = P.RealDiv()
  918. if a < b:
  919. a = add(a, b)
  920. else:
  921. a = sub(a, b)
  922. if 2 > 1:
  923. a = mul(a, b)
  924. else:
  925. a = div(a, b)
  926. if b == x:
  927. b = add(a, b)
  928. else:
  929. b = add(a, x)
  930. a = a * b
  931. out = a + b + x
  932. return out
  933. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  934. if_net = MyIfByIfNet()
  935. net = if_net
  936. idx = Tensor(np.array(2), dtype=ms.float32)
  937. end = Tensor(np.array(3), dtype=ms.float32)
  938. x = Tensor(np.array(0), dtype=ms.float32)
  939. net(idx, end, x)
  940. def test_if_by_if_forward_all_const_branch():
  941. class MyIfByIfNet(nn.Cell):
  942. def __init__(self):
  943. super().__init__()
  944. self.add = P.TensorAdd()
  945. self.sub = P.Sub()
  946. self.mul = P.Mul()
  947. self.div = P.RealDiv()
  948. def construct(self, a, b, x):
  949. add = P.TensorAdd()
  950. sub = P.Sub()
  951. mul = P.Mul()
  952. div = P.RealDiv()
  953. if 2 < 12:
  954. a = add(a, b)
  955. else:
  956. a = sub(a, b)
  957. if 2 > 1:
  958. a = mul(a, b)
  959. else:
  960. a = div(a, b)
  961. if 2 == 1:
  962. b = add(a, b)
  963. else:
  964. b = add(a, x)
  965. a = a * b
  966. out = a + b + x
  967. return out
  968. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  969. if_net = MyIfByIfNet()
  970. net = if_net
  971. idx = Tensor(np.array(2), dtype=ms.float32)
  972. end = Tensor(np.array(3), dtype=ms.float32)
  973. x = Tensor(np.array(0), dtype=ms.float32)
  974. net(idx, end, x)