You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_control_ops.py 30 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test control ops """
import os
import numpy as np
import pytest
import mindspore as ms
from mindspore import Tensor
from mindspore import context
from mindspore import nn
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import ms_function

# Every test in this file exercises control-flow compilation in graph mode.
context.set_context(mode=context.GRAPH_MODE)

# Shared gradient operators:
#   grad_by_list       - gradients w.r.t. a ParameterTuple of weights
#   grad_all           - gradients w.r.t. all network inputs
#   grad_all_with_sens - gradients w.r.t. all inputs, with an explicit
#                        sensitivity (dout) argument appended to the call
grad_by_list = C.GradOperation(get_by_list=True)
grad_all = C.GradOperation(get_all=True)
grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)
def cond_data_test(x_init, y_init):
    """Run a GeSwitch/Merge graph: returns x + y when x < y, else square(3)."""
    class Net(nn.Cell):
        def __init__(self):
            """Declare the primitive ops and the constant used by the graph."""
            super(Net, self).__init__()
            self.square = P.Square()
            self.add = P.Add()
            self.value = Tensor(3, dtype=ms.float32)
            self.switch = P.GeSwitch()
            self.merge = P.Merge()
            self.less = P.Less()

        def construct(self, x, y):
            cond = self.less(x, y)
            # true-branch outputs of the switches feed the add
            st1, _ = self.switch(x, cond)
            st2, _ = self.switch(y, cond)
            add_ret = self.add(st1, st2)
            # false-branch output carries the constant into the square
            _, sf3 = self.switch(self.value, cond)
            sq_ret = self.square(sf3)
            # merge yields whichever branch actually executed
            ret = self.merge((add_ret, sq_ret))
            return ret[0]

    x = Tensor(x_init, dtype=ms.float32)
    y = Tensor(y_init, dtype=ms.float32)
    net = Net()
    output = net(x, y)
    return output
  58. def test_cond_data_true():
  59. output = cond_data_test(3, 8)
  60. print("test_cond_data_true:", output)
  61. def test_cond_data_false():
  62. output = cond_data_test(8, 3)
  63. print("test_cond_data_false:", output)
def if_compile_test(x_init, y_init):
    """Compile a Python `if` over a Tensor condition inside construct."""
    class Net(nn.Cell):
        def __init__(self):
            """Declare square/add ops and the constant 3.0."""
            super(Net, self).__init__()
            self.square = P.Square()
            self.add = P.Add()
            self.value = Tensor(3, dtype=ms.float32)
            self.switch = P.GeSwitch()
            self.merge = P.Merge()
            self.less = P.Less()

        def construct(self, x, y):
            # cond is a Tensor, so this `if` compiles to a graph switch
            cond = self.less(x, y)
            ret = self.value
            if cond:
                ret = self.add(x, ret)
                ret = self.add(y, ret)
            else:
                ret = self.square(self.value)
            return ret

    x = Tensor(x_init, dtype=ms.float32)
    y = Tensor(y_init, dtype=ms.float32)
    net = Net()
    output = net(x, y)
    return output
def test_if_none():
    """A `None` attribute is falsy in a graph `if`, so y is returned."""
    class Net(nn.Cell):
        def __init__(self, z: None):
            """Store the compile-time condition value `z`."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z:
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = None
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == y.asnumpy())
def test_if_str_is_not_none_right():
    """`"ok" is None` is False at compile time, so the else branch returns y."""
    class Net(nn.Cell):
        def __init__(self, z: str):
            """Store the compile-time condition value `z`."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z is None:
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = "ok"
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == y.asnumpy())
def test_if_str_is_not_none_left():
    """Same check as the `_right` variant: `"ok" is None` is False, returns y.

    NOTE(review): this body is identical to test_if_str_is_not_none_right;
    the name suggests the operands were meant to be reversed
    (`None is self.z`) — confirm intent against upstream history.
    """
    class Net(nn.Cell):
        def __init__(self, z: str):
            """Store the compile-time condition value `z`."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z is None:
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = "ok"
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == y.asnumpy())
def test_if_none_equal_none():
    """`None is None` is True at compile time, so the if branch returns x."""
    class Net(nn.Cell):
        def __init__(self, z: None):
            """Store the compile-time condition value `z`."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z is None:
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = None
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == x.asnumpy())
def test_if_str_is_null():
    """An empty string attribute is falsy in a graph `if`, so y is returned."""
    class Net(nn.Cell):
        def __init__(self, z: str):
            """Store the compile-time condition value `z`."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z:
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = ""
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == y.asnumpy())
def test_if_str_is_true():
    """A non-empty string attribute is truthy, so x is returned."""
    class Net(nn.Cell):
        def __init__(self, z: str):
            """Store the compile-time condition value `z`."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z:
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 9, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = "ok"
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == x.asnumpy())
def test_if_str_equal():
    """String equality against a literal resolves at compile time; x returned."""
    class Net(nn.Cell):
        def __init__(self, z: str):
            """Store the compile-time condition value `z`."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z == "ok":
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = "ok"
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == x.asnumpy())
def test_if_tuple_is_null():
    """An empty tuple attribute is falsy in a graph `if`, so y is returned."""
    class Net(nn.Cell):
        def __init__(self, z: tuple):
            """Store the compile-time condition value `z`."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z:
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = ()
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == y.asnumpy())
def test_if_tuple_is_not_null():
    """A non-empty tuple attribute is truthy, so x is returned."""
    class Net(nn.Cell):
        def __init__(self, z: tuple):
            """Store the compile-time condition value `z`."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z:
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = (1, 2, 3)
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == x.asnumpy())
def test_if_dict_is_null():
    """An empty dict attribute is falsy in a graph `if`, so y is returned."""
    class Net(nn.Cell):
        def __init__(self, z: dict):
            """Store the compile-time condition value `z`."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z:
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = {}
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == y.asnumpy())
def test_if_dict_is_not_null():
    """A non-empty dict attribute is truthy, so x is returned."""
    class Net(nn.Cell):
        def __init__(self, z: dict):
            """Store the compile-time condition value `z`."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z:
                ret = x
            else:
                ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = {"one": 1, "two": 2}
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == x.asnumpy())
def test_if_else_assign():
    """Conditional expressions over list attributes feed nested if/else.

    With m=[1, 2]: exp_1 = m (truthy), exp_1 != n so exp_2 = n,
    exp_2 != m so the outer else runs, and m truthy selects x.
    """
    class Net(nn.Cell):
        def __init__(self, m: list):
            """Store list `m`; `n` is a fixed comparison list."""
            super(Net, self).__init__()
            self.m = m
            self.n = [4, 5, 6]

        def construct(self, x, y):
            exp_1 = self.m if self.m else self.n
            exp_2 = self.m if exp_1 == self.n else self.n
            if exp_2 == self.m:
                if self.m:
                    ret = x
                else:
                    ret = y
            else:
                if self.m:
                    ret = x
                else:
                    ret = y
            return ret

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.zeros([3, 4, 5], np.int32))
    z = [1, 2]
    net = Net(z)
    assert np.all(net(x, y).asnumpy() == x.asnumpy())
  302. def test_if_compile_true():
  303. output = if_compile_test(3, 8)
  304. print("test_if_compile_true:", output)
  305. def test_if_compile_false():
  306. output = if_compile_test(8, 3)
  307. print("test_if_compile_false:", output)
def test_switch_layer():
    """F.switch_layer selects a sub-cell by Tensor index; grads must compile."""
    class Layer1(nn.Cell):
        def __init__(self):
            super(Layer1, self).__init__()
            self.z1 = Parameter(
                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')

        def construct(self, x):
            return x * self.z1

    class Layer2(nn.Cell):
        def __init__(self):
            super(Layer2, self).__init__()
            self.z2 = Parameter(
                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')

        def construct(self, x):
            return x * self.z2

    class SwitchLayerCell(nn.Cell):
        def __init__(self):
            super(SwitchLayerCell, self).__init__()
            self.layers = (Layer1(), Layer2())
            self.z3 = Parameter(
                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')

        def construct(self, index, x):
            # switch_layer dispatches to self.layers[index] at runtime
            ret = F.switch_layer(index, self.layers)(x) * self.z3
            return ret

    index = Tensor(0, dtype=mstype.int32)
    net = SwitchLayerCell()
    net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    # both parameter grads and input grads must compile through the switch
    grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
                                                              Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
def test_index_to_switch_layer():
    """Indexing a tuple of cells with a Tensor index lowers to switch_layer."""
    class Layer1(nn.Cell):
        def __init__(self):
            super(Layer1, self).__init__()
            self.z1 = Parameter(
                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')

        def construct(self, x):
            return x * self.z1

    class Layer2(nn.Cell):
        def __init__(self):
            super(Layer2, self).__init__()
            self.z2 = Parameter(
                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')

        def construct(self, x):
            return x * self.z2

    class SwitchLayerCell(nn.Cell):
        def __init__(self):
            super(SwitchLayerCell, self).__init__()
            self.layers = (Layer1(), Layer2())
            self.z3 = Parameter(
                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')

        def construct(self, index, x):
            # plain tuple indexing with a Tensor index, not F.switch_layer
            ret = self.layers[index](x) * self.z3
            return ret

    index = Tensor(0, dtype=mstype.int32)
    net = SwitchLayerCell()
    net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
                                                              Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
def test_parser_switch_layer_switch_in_bprop():
    """A custom bprop that indexes a tuple of cells must compile."""
    class OneInputBprop(nn.Cell):
        def __init__(self, funcs):
            """Store the tuple of candidate backward cells."""
            super(OneInputBprop, self).__init__()
            self.op = P.ReLU()
            self.funcs = funcs

        def construct(self, i, x):
            return self.op(x)

        def bprop(self, i, x, out, dout):
            # the switch-by-index happens inside the backward graph
            return i, self.funcs[i](x, dout)

    class Add(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = P.Add()

        def construct(self, x, y):
            return self.op(x, y)

    class Mul(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = P.Mul()

        def construct(self, x, y):
            return self.op(x, y)

    func1 = Add()
    func2 = Mul()
    funcs = (func1, func2)
    net = OneInputBprop(funcs)
    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
    grad = Tensor(np.random.randn(2, 2).astype(np.float32))
    i = Tensor(1, mstype.int32)
    grad_net = grad_all_with_sens(net)
    grad_net(i, input1, grad)
def test_parser_switch_layer_inputs_tuple():
    """switch_layer branches that take a tuple argument must compile with grads."""
    class TwoInputTupleFinalNet(nn.Cell):
        def __init__(self, funcs):
            """Store the tuple of candidate cells."""
            super().__init__()
            self.funcs = funcs

        def construct(self, i, inputa, inputb):
            # the selected branch receives a tuple, not positional tensors
            inputs = (inputa, inputb)
            x = self.funcs[i](inputs)
            return x

    class Add(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = P.Add()

        def construct(self, x):
            y = self.op(x[0], x[1])
            return self.op(x[0], y)

    class Mul(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = P.Mul()

        def construct(self, x):
            y = self.op(x[0], x[1])
            return self.op(x[0], y)

    func1 = Add()
    func2 = Mul()
    funcs = (func1, func2)
    net = TwoInputTupleFinalNet(funcs)
    input1 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    input2 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    i = Tensor(1, mstype.int32)
    grad = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    back_net = grad_all_with_sens(net)
    back_out = back_net(i, input1, input2, grad)
def test_switch_layer_with_single_prim():
    """switch_layer over a tuple of identical single-op cells must compile."""
    class SwitchLayerCell(nn.Cell):
        def __init__(self):
            super(SwitchLayerCell, self).__init__()
            self.layers = (nn.ReLU(), nn.ReLU())
            self.z3 = Parameter(
                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')

        def construct(self, index, x):
            ret = self.layers[index](x) * self.z3
            return ret

    index = Tensor(0, dtype=mstype.int32)
    net = SwitchLayerCell()
    net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
                                                              Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
def test_switch_layer_env_eliminate():
    """Parameter grads through a conv chosen by Tensor index must compile."""
    class Net(nn.Cell):
        def __init__(self):
            """Two convs with different kernel sizes as switchable branches."""
            super(Net, self).__init__()
            self.conv = nn.Conv2d(1, 1, 3, pad_mode='same')
            self.conv2 = nn.Conv2d(1, 1, 5, pad_mode='same')
            self.funs = (self.conv, self.conv2)

        def construct(self, x, index):
            x = self.funs[index](x)
            return x

    class NetGrad(nn.Cell):
        def __init__(self, net):
            """Wrap `net` with a by-parameter-list GradOperation."""
            super(NetGrad, self).__init__()
            self.grad_op = C.GradOperation(get_by_list=True, sens_param=False)
            self.net = net
            self.weights = ParameterTuple(self.net.trainable_params())

        def construct(self, x, index):
            weights = self.weights
            grad = self.grad_op(self.net, weights)(x, index)
            return grad

    net = Net()
    net2 = NetGrad(net)
    x = Tensor(np.ones((3, 1, 12, 12)), ms.float32)
    i = Tensor(1, ms.int32)
    net2(x, i)
def test_switch_layer_single_layer():
    """Degenerate switch_layer with a one-element tuple must still compile."""
    class Net(nn.Cell):
        def __init__(self):
            """Single conv wrapped in a one-element tuple of branches."""
            super(Net, self).__init__()
            self.conv = nn.Conv2d(1, 1, 3, pad_mode='same')
            self.funs = (self.conv,)

        def construct(self, x, index):
            x = self.funs[index](x)
            return x

    class NetGrad(nn.Cell):
        def __init__(self, net):
            """Wrap `net` with a by-parameter-list GradOperation."""
            super(NetGrad, self).__init__()
            self.grad_op = C.GradOperation(get_by_list=True, sens_param=False)
            self.net = net
            self.weights = ParameterTuple(self.net.trainable_params())

        def construct(self, x, index):
            weights = self.weights
            grad = self.grad_op(self.net, weights)(x, index)
            return grad

    net = Net()
    net2 = NetGrad(net)
    x = Tensor(np.ones((3, 1, 12, 12)), ms.float32)
    i = Tensor(1, ms.int32)
    net2(x, i)
def test_if_nested_compile():
    """An `if` nested inside an `else` over Tensor conditions must compile."""
    class Net(nn.Cell):
        def __init__(self, auto_prefix=True):
            # NOTE(review): attribute name `squre` is a misspelling of
            # `square`; it is internal and consistent, so left unchanged.
            super().__init__(auto_prefix=auto_prefix)
            self.squre = P.Square()
            self.value = Tensor(3, dtype=ms.float32)

        def construct(self, x, y):
            res = self.value
            if x <= y:
                res = x + res
                res = y + res
            else:
                # nested condition inside the else branch
                if x == y:
                    res = self.squre(self.value * y)
                else:
                    res = self.squre(self.value)
            return res

    x = Tensor(1.0, dtype=ms.float32)
    y = Tensor(2.0, dtype=ms.float32)
    net = Net()
    net(x, y)
def test_if_inside_for():
    """A Tensor-conditioned `if` inside an unrolled `for` must compile."""
    class Net(nn.Cell):
        def __init__(self, auto_prefix=True):
            super().__init__(auto_prefix=auto_prefix)
            self.squre = P.Square()
            self.value = Tensor(3, dtype=ms.float32)
            self.count = 4

        def construct(self, x, y):
            res = 0
            # the loop unrolls over a Python int range; each iteration
            # compares the loop int against the Tensor input x
            for i in range(self.count):
                if i == x:
                    res = res + x
                else:
                    res = res - y
            return res

    c1 = Tensor(1, dtype=ms.int32)
    c2 = Tensor(1, dtype=ms.int32)
    net = Net()
    net(c1, c2)
def test_while_in_while():
    """Nested `while` loops inside an @ms_function must compile and run."""
    c1 = Tensor(1, dtype=ms.int32)
    c2 = Tensor(2, dtype=ms.int32)
    c3 = Tensor(3, dtype=ms.int32)
    c4 = Tensor(4, dtype=ms.int32)

    @ms_function
    def while_in_while(x, y, z, u):
        # NOTE(review): `u` is unused and `z` is overwritten before use;
        # both appear intentional to exercise argument handling.
        out = c4  # c4 captured from the enclosing scope as a constant
        while x < y:
            z = c4 + c4
            while z < y:
                z = z + 1
                out = out + 1
            x = x + 1
        out = out + 3
        return out

    while_in_while(c1, c2, c3, c4)
  554. def test_tensor_cond():
  555. class Net(nn.Cell):
  556. def __init__(self):
  557. super(Net, self).__init__()
  558. self.t = Tensor(np.array(0, np.bool))
  559. self.t1 = Tensor(np.array([True], np.bool))
  560. def construct(self, x, y):
  561. t = 0
  562. if self.t:
  563. t = t - x * y
  564. else:
  565. t = t - x / y
  566. if self.t1:
  567. t = t + x / y
  568. else:
  569. t = t + x * y
  570. return t
  571. x = Tensor(np.ones([6, 8, 10], np.int32))
  572. y = Tensor(np.ones([6, 8, 10], np.int32))
  573. net = Net()
  574. out = net(x, y)
  575. def test_tensor_cond_exception():
  576. class Net(nn.Cell):
  577. def __init__(self):
  578. super(Net, self).__init__()
  579. self.t = Tensor(np.array([True, False], np.bool))
  580. def construct(self, x, y):
  581. t = 0
  582. if self.t:
  583. t = t - x * y
  584. else:
  585. t = t - x / y
  586. return t
  587. x = Tensor(np.ones([6, 8, 10], np.int32))
  588. y = Tensor(np.ones([6, 8, 10], np.int32))
  589. net = Net()
  590. with pytest.raises(ValueError):
  591. out = net(x, y)
def test_while_scalar():
    """A `while` over plain Python ints unrolls to a fixed 10 iterations."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            # NOTE(review): self.x is never read; appears to be leftover.
            self.x = 10

        def construct(self, x, y):
            i = 0
            t = 0
            while (i < 10):
                t = t + x + y
                i = i + 1
            return t

    net = Net()
    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.ones([6, 8, 10], np.int32))
    out = net(x, y)
def test_while_with_weight_in_condition():
    """Grad through a `while` whose condition reads a mutated Parameter."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.loop = Parameter(Tensor(1, dtype=ms.float32), name="loop")

        def construct(self, x):
            # the loop condition depends on a Parameter updated in the body
            while self.loop < 5:
                self.loop += 1
                x += 1
            return x

    net = Net()
    x = Tensor(-1, dtype=ms.float32)
    grad_all(net)(x)
  621. def test_mixed_precision_cast():
  622. x = Tensor(np.ones([2, 3], dtype=np.float32))
  623. z = F.mixed_precision_cast(mstype.float16, x)
  624. assert z.dtype == mstype.float16
def test_while_add():
    """A `while` loop accumulating row slices of the input must compile."""
    class Net(nn.Cell):
        def __init__(self, data):
            # NOTE(review): `data` is unused; kept for call compatibility.
            super(Net, self).__init__()
            self.start = Tensor(0, dtype=mstype.int32)
            self.end = Tensor(2, dtype=mstype.int32)
            self.out = Tensor(np.zeros([2, 3], dtype=np.float32))
            self.add = P.Add()

        def construct(self, inputs):
            idx = self.start
            end = self.end
            out = self.out
            while idx < end:
                # slice one [2, 3] plane per iteration and accumulate it
                xi = inputs[idx, :, :]
                out = self.add(out, xi)
                idx = idx + 1
            return out

    x = Tensor(np.arange(10 * 2 * 3).reshape(10, 2, 3).astype(np.float32))
    net = Net(x)
    net(x)
def test_tensor_all_construct_lack_branch():
    """A construct whose fall-through path returns nothing must raise."""
    class NetConditionLackBranch(nn.Cell):
        def __init__(self):
            super(NetConditionLackBranch, self).__init__()
            self.logicaland = P.LogicalAnd()
            self.logicalor = P.LogicalOr()

        def construct(self, input1, input2):
            if input1.all():
                return self.logicaland(input1, input2)
            while input1.any():
                return self.logicalor(input1, input2)
            # NOTICE: here missing return statement, default return None

    input_np_1 = np.random.choice([True], size=(2, 3, 4, 5))
    input_tensor_1 = Tensor(input_np_1)
    input_np_2 = np.random.choice([True, False], size=(2, 3, 4, 5))
    input_tensor_2 = Tensor(input_np_2)
    net = NetConditionLackBranch()
    # graph compilation must reject the missing-return branch
    with pytest.raises(Exception):
        net(input_tensor_1, input_tensor_2)
def test_parser_switch_layer_func_primitive():
    """Indexing a tuple of raw Primitives (not Cells) must raise ValueError."""
    class FinalNet(nn.Cell):
        def __init__(self, funcs):
            """Store the tuple of primitive ops."""
            super().__init__()
            self.funcs = funcs

        def construct(self, i, input1):
            x = self.funcs[i](input1)
            return x

    func1 = P.ReLU()
    func2 = P.Softmax()
    funcs = (func1, func2)
    net = FinalNet(funcs)
    input1 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    i = Tensor(1, mstype.int32)
    with pytest.raises(ValueError):
        net(i, input1)
def test_switch_layer_shape_join_failed():
    """Branches expecting incompatible inputs must fail to join (ValueError)."""
    class AddFuncNet(nn.Cell):
        def __init__(self, funcs, new_func):
            """Store base branch tuple and an extra branch appended later."""
            super(AddFuncNet, self).__init__()
            self.funcs = funcs
            self.new_func = new_func

        def construct(self, i, inputs):
            # tuple concatenation happens inside the graph
            final_funcs = self.funcs + (self.new_func,)
            x = final_funcs[i](inputs)
            return x

    class ReLUTuple(nn.Cell):
        def __init__(self):
            super(ReLUTuple, self).__init__()
            self.op = nn.ReLU()

        def construct(self, x):
            # expects a tuple input, unlike the other branches
            return self.op(x[0])

    func1 = nn.Softmax()
    func2 = nn.ReLU()
    func3 = ReLUTuple()
    funcs = (func1, func2)
    net = AddFuncNet(funcs, func3)
    inp = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    i = Tensor(1, mstype.int32)
    with pytest.raises(ValueError) as err:
        net(i, inp)
def test_switch_layer_dtype_join_failed():
    """Branches producing different dtypes must fail to join (TypeError)."""
    class Cast(nn.Cell):
        def __init__(self, dtype):
            """Cast input to `dtype` and double it."""
            super(Cast, self).__init__()
            self.op = P.Cast()
            self.dtype = dtype

        def construct(self, x):
            y = self.op(x, self.dtype)
            return y + y

    class SwitchNegNet(nn.Cell):
        def __init__(self, funcs):
            """Apply the selected branch, then negate its output."""
            super(SwitchNegNet, self).__init__()
            self.funcs = funcs
            self.op = P.Neg()

        def construct(self, i, inputs):
            x = self.funcs[i](inputs)
            x = self.op(x)
            return x

    # branch outputs are float32 (ReLU) vs int32 (Cast) — dtypes cannot join
    func1 = nn.ReLU()
    func2 = Cast(mstype.int32)
    funcs = (func1, func2)
    net = SwitchNegNet(funcs)
    inp = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    i = Tensor(0, mstype.int32)
    with pytest.raises(TypeError) as err:
        net(i, inp)
  731. def test_large_for_loop():
  732. class Net(nn.Cell):
  733. def __init__(self):
  734. super(Net, self).__init__()
  735. self.flatten = P.ReLU() # nn.Flatten()
  736. def construct(self, x):
  737. for elem in range(1, 1900):
  738. x = self.flatten(x + elem)
  739. return x
  740. t = Tensor(np.ones([2, 3], dtype=np.float32))
  741. net = Net()
  742. os.environ['MS_DEV_RECURSIVE_EVAL'] = '1'
  743. old_max_call_depth = context.get_context('max_call_depth')
  744. context.set_context(max_call_depth=60)
  745. with pytest.raises(RuntimeError) as err:
  746. net(t)
  747. context.set_context(max_call_depth=old_max_call_depth)
  748. os.environ['MS_DEV_RECURSIVE_EVAL'] = '0'
  749. assert 'Exceed function call depth limit 60' in str(err.value)
  750. def test_large_for_loop_case2():
  751. class Menet(nn.Cell):
  752. def __init__(self, axis, flag_boottom, flag_top):
  753. super(Menet, self).__init__()
  754. self.squeeze = P.Squeeze(axis)
  755. self.expanddims = P.ExpandDims()
  756. self.flatten = nn.Flatten()
  757. self.neg = P.Neg()
  758. self.axis = axis
  759. self.flag_boottom = flag_boottom
  760. self.flag_top = flag_top
  761. def construct(self, x):
  762. if self.flag_boottom:
  763. x = self.neg(x)
  764. for i in range(0, 1500):
  765. x = self.expanddims(x, self.axis)
  766. x = self.squeeze(x)
  767. x = self.flatten(x)
  768. if self.flag_top:
  769. x = self.neg(x)
  770. return x
  771. x = Tensor(np.ones([2, 3], dtype=np.float32))
  772. net = Menet(axis=0, flag_boottom=True, flag_top=True)
  773. os.environ['MS_DEV_RECURSIVE_EVAL'] = '1'
  774. old_max_call_depth = context.get_context('max_call_depth')
  775. context.set_context(max_call_depth=80)
  776. with pytest.raises(RuntimeError) as err:
  777. net(x)
  778. os.environ['MS_DEV_RECURSIVE_EVAL'] = '0'
  779. context.set_context(max_call_depth=old_max_call_depth)
  780. assert 'Exceed function call depth limit 80' in str(err.value)
  781. def test_large_for_loop_with_continue_break():
  782. class Net(nn.Cell):
  783. def __init__(self):
  784. super(Net, self).__init__()
  785. self.flatten = P.ReLU() # nn.Flatten()
  786. def construct(self, x):
  787. idx = 0
  788. for elem1 in range(200):
  789. idx = idx + 1
  790. if idx < 10:
  791. x = x + 0.5
  792. continue
  793. if idx > 500:
  794. break
  795. x = self.flatten(x + elem1)
  796. return x
  797. os.environ['MS_DEV_RECURSIVE_EVAL'] = '1'
  798. old_max_call_depth = context.get_context('max_call_depth')
  799. context.set_context(max_call_depth=2000)
  800. t = Tensor(np.ones([2, 3], dtype=np.float32))
  801. net = Net()
  802. net(t)
  803. os.environ['MS_DEV_RECURSIVE_EVAL'] = '0'
  804. context.set_context(max_call_depth=old_max_call_depth)
  805. def test_recursive_call():
  806. class Net(nn.Cell):
  807. """ Net definition """
  808. def __init__(self):
  809. super(Net, self).__init__()
  810. self.fc = nn.Dense(10, 10) # padding=0
  811. # self.net2 = Net2()
  812. def construct(self, x):
  813. net2 = Net2()
  814. x = net2(x)
  815. out = self.fc(x)
  816. return out
  817. class Net2(nn.Cell):
  818. def __init__(self):
  819. super(Net2, self).__init__()
  820. self.net = Net()
  821. self.fc = nn.Dense(10, 10)
  822. def construct(self, x):
  823. x = self.net(x)
  824. out = self.fc(x)
  825. return out
  826. context.set_context(mode=context.GRAPH_MODE)
  827. os.environ['MS_DEV_RECURSIVE_EVAL'] = '1'
  828. old_max_call_depth = context.get_context('max_call_depth')
  829. context.set_context(max_call_depth=80)
  830. input_data = Tensor(np.identity(10).astype(np.float32))
  831. net = Net2()
  832. with pytest.raises(RuntimeError):
  833. net(input_data)
  834. os.environ['MS_DEV_RECURSIVE_EVAL'] = '0'
  835. context.set_context(max_call_depth=old_max_call_depth)
  836. # grad for Tensor(Bool) input and eliminate AddN(MakeTuple(Xs, zeros_like(Bool)))
  837. def test_grad_tensor_bool():
  838. class Net(nn.Cell):
  839. def __init__(self):
  840. super(Net, self).__init__()
  841. def construct(self, x, y, z):
  842. out = z
  843. while x:
  844. out = out + z
  845. x = y
  846. return out
  847. x = Tensor(np.array(False).astype(np.bool))
  848. y = Tensor(np.array(False).astype(np.bool))
  849. z = Tensor(np.ones([2, 3], dtype=np.float32))
  850. net = grad_all(Net())
  851. net(x, y, z)