# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_framstruct """
import pytest
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
from mindspore.common import dtype as mstype
from mindspore.nn.wrap.cell_wrapper import WithGradCell, WithLossCell
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.utils.check_gradient import (
    ms_function, check_jacobian, Tensor, NNGradChecker,
    OperationGradChecker, check_gradient, ScalarGradChecker)
from ....mindspore_test_framework.utils.bprop_util import bprop
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer


def setup_module(module):
    context.set_context(mode=context.PYNATIVE_MODE)


@ms_function
def while_upper_bound(upper):
    rval = 2
    while rval < upper:
        rval = rval * rval
    return rval


def test_while_upper_bound():
    res = while_upper_bound(10)
    assert res == 16


@ms_function
def while_lower_bound(lower):
    """ t_while """
    rval = lower
    while rval < 100:
        rval = rval * rval
    return rval


def test_while_lower_bound():
    res = while_lower_bound(2)
    assert res == 256


@ms_function
def dynamic_make_tuple(x, lower, upper):
    out = ()
    i = lower
    while i < upper:
        out = out + (x,)
        i = i + 1
    return out


def test_dynamic_make_tuple():
    # Dynamically and recursively creating a static type is invalid in MindSpore,
    # as MindSpore is a statically typed language.
    with pytest.raises(RuntimeError):
        dynamic_make_tuple(2, 1, 5)


def test_make_tuple():
    # Statically and recursively creating a static type is valid in MindSpore.
    @ms_function
    def make_tuple(x):
        out = ()
        for i in range(3):
            out = out + (x,)
        return out

    res = make_tuple(5)
    assert res == (5, 5, 5)


@ms_function
def add(x, y):
    """ add """
    return x + y


def mul(x, y):
    """ mul """
    return x * y


def add_mul(x, y):
    """ add_mul """
    return (x + y) * y


def mainf(x, y):
    """ mainf """
    return C.grad_all(mul)(x, y)


def grad_add_mul(x, y):
    """ grad_add_mul """
    return C.grad_all(add_mul)(x, y)


@ms_function
def sub(x, y):
    """ sub """
    return x - y


@ms_function
def if_always_true(x):
    """ if_always_true """
    if True:
        return x
    else:
        return 0


def test_add():
    """ test_add """
    res = add(2.5, 3)
    assert res == 5.5


def test_sub():
    """ test_sub """
    res = sub(3.5, 3)
    assert res == 0.5


@non_graph_engine
def test_if_always_true():
    """ test_if_always_true """
    res = if_always_true(1)
    assert res == 1


@non_graph_engine
def test_f():
    """ test_f """
    res = mainf(3, 2)
    assert res == (2, 3)


@non_graph_engine
def test_grad_add_mul():
    """ test_grad_add_mul """
    res = grad_add_mul(3, 2)
    assert res == (2, 7)
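    # Worked check: add_mul(x, y) = (x + y) * y, so at (3, 2) the gradients are
    # d/dx = y = 2 and d/dy = x + 2 * y = 7.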


def f(x):
    if x > 0:
        return f(x - 1)
    return x


@ms_function
def list_subscript():
    """ list_subscript """
    x = [1, 2, 3]
    return x[0] * x[1]


def test_list_subscript():
    """ test_list_subscript """
    res = list_subscript()
    assert res == 2


@ms_function
def ms_infer_for(xs, y):
    """ ms_infer_for """
    rval = y
    for x in xs:
        rval = rval + x
    return rval


def test_infer_for():
    """ test_infer_for """
    t = (1, 2, 3)
    y = 4
    res = ms_infer_for(t, y)
    assert res == 10


@ms_function
def if_construct(a, b):
    z = a
    if a > b:
        z = a + b
    else:
        z = a * b
    if z > b:
        return z - a
    else:
        return a - b


def test_if_construct():
    """ test_if_construct """
    res = if_construct(3, 6)
    assert res == 15
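    # Trace: a = 3 is not > b = 6, so z = a * b = 18; 18 > 6, giving z - a = 15.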


@ms_function
def if_scalar(a, b):
    """ if_abstract """
    if a:
        return a
    return b


def test_if_scalar1():
    """ test_if_abstract """
    res = if_scalar(3, 6)
    assert res == 3


def test_if_scalar2():
    """ test_if_abstract """
    res = if_scalar(0, 6)
    assert res == 6


@ms_function
def if_tensor(a, b):
    c = a
    if a < b:
        c = a + a
        if c < b:
            c = a + c
        else:
            c = a + b
    else:
        c = b + b
    out = c + c
    return out


def test_if_tensor():
    res = if_tensor(Tensor(np.ones([64, 10]).astype(np.int32)), Tensor(np.ones([64, 10]).astype(np.int32)))
    assert res == Tensor(np.ones([64, 10]).astype(np.int32) * 4)


@ms_function
def rec(x):
    """ rec """
    if x > 0:
        return rec(x - 1)
    return x


def test_grad_rec():
    """ test_grad_rec """
    res = C.grad(rec)(10)
    assert res == 1
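    # rec(x) unrolls to ((x - 1) - 1) - ... until the argument reaches 0, i.e.
    # rec(x) = x - 10 around x = 10, so the derivative w.r.t. x is 1.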


def test_me_rec():
    """ test_me_rec """
    res = rec(10)
    assert res == 0


@ms_function
def t2_while(x, y):
    out = y - x
    i = 0
    while i < 10:
        out = mul(x, y)
        i = i + 1
    return out


def test_while2():
    res = t2_while(2, 3)
    assert res == 6


def test_grad_while2():
    res = C.grad(t2_while)(2, 3)
    assert res == 3
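    # Every iteration overwrites out with mul(x, y), so the final output is
    # x * y and the gradient w.r.t. the first input is y = 3.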


def if_test(a, b):
    """ if_test """
    if a > b:
        return 3 * a
    return 2 * b


def grad_if(x, y):
    """ grad_if """
    return C.grad_all(if_test)(x, y)


def test_grad_if():
    """ test_grad_if """
    assert grad_if(5, 4) == (3, 0)
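    # a = 5 > b = 4 takes the 3 * a branch: gradient 3 w.r.t. a and 0 w.r.t. b.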


# While loop is not unrolled in forward and backward graphs.
def test_dont_unroll_while():
    def dont_unroll_while(x, y):
        i = 2
        out = y - x
        while i < 10:
            out = mul(x, y)
            i = i + 1
        return out

    @ms_function()
    def invoke_while(x, y):
        return C.grad(dont_unroll_while)(x, y)

    res = invoke_while(2, 3)
    assert res == 3


class ConvNet(nn.Cell):
    def __init__(self):
        super(ConvNet, self).__init__()
        out_channel = 16
        kernel_size = 3
        self.conv = P.Conv2D(out_channel,
                             kernel_size,
                             mode=1,
                             pad_mode="pad",
                             pad=0,
                             stride=1,
                             dilation=2,
                             group=1)
        self.w = Parameter(Tensor(np.ones([16, 16, 3, 3]).astype(np.float32)), name='w')

    def construct(self, x):
        return self.conv(x, self.w)


conv = ConvNet()
c1 = Tensor([2], mstype.float32)
c2 = Tensor([10], mstype.float32)
c3 = Tensor([1], mstype.float32)


@ms_function
def t1_while(x, y, z):
    out = x
    i = c1
    while i < c2:
        out = out + conv(z)
        i = i + c3
    out = out + out
    return out


def test_while_net():
    y = Tensor(np.ones([1, 3, 3, 4]).astype(np.float32))
    x = Tensor(np.ones([1, 16, 12, 12]).astype(np.float32))
    z = Tensor(np.ones([1, 16, 16, 16]).astype(np.float32))
    res = t1_while(x, y, z)
    assert res == Tensor(np.ones([1, 16, 12, 12]).astype(np.float32) * 2306.0)
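    # Expected value: with all-ones input and weights, each conv output element
    # sums 16 * 3 * 3 = 144 taps (pad=0, so every tap lands in bounds; the 3x3
    # kernel with dilation 2 shrinks 16x16 to 12x12). The loop runs 8 times
    # (i = 2..9), so out = 1 + 8 * 144 = 1153, and out + out gives 2306.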


@ms_function
def if_while(a, b, x, z):
    c = a
    i = c1
    out = x
    if a < b:
        c = a + a
        while i < c2:
            out = out + conv(z)
            i = i + c3
    else:
        c = b + b
    out = c + c
    return out


def test_if_while():
    x = Tensor(np.random.randn(1, 16, 12, 12).astype(np.float32))
    z = Tensor(np.random.randn(1, 16, 16, 16).astype(np.float32))
    res = if_while(Tensor(np.ones([64, 10]).astype(np.float32)), Tensor(np.ones([64, 10]).astype(np.float32)), x, z)
    assert res == Tensor(np.ones([64, 10]).astype(np.float32) * 4.0)


def _while(x):
    """ _while """
    ret = x * x
    i = 2
    while i <= 3:
        ret = ret * i
        i = i + 1
    return ret


def grad_while(x):
    """ grad_while """
    return C.grad_all(_while)(x)


def test_grad_while():
    """ test_grad_while """
    assert grad_while(5) == (60,)
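    # _while computes (x * x) * 2 * 3 = 6 * x**2, so the gradient is 12 * x = 60 at x = 5.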


@ms_function
def factorial(n):
    """ factorial """
    if n == 0:
        return 1
    return n * factorial(n - 1)


def test_factorial():
    res = factorial(3)
    assert res == 6


def test_grad_factorial():
    res = C.grad(factorial)(3)
    assert res == 11
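    # Differentiating f(n) = n * f(n - 1) gives f'(n) = f(n - 1) + n * f'(n - 1)
    # with f'(0) = 0: f'(1) = 1, f'(2) = 1 + 2 * 1 = 3, f'(3) = 2 + 3 * 3 = 11.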


def _for(x):
    """ _for """
    ret = x * x
    for i in (2, 3):
        ret = ret * i
    return ret


def grad_for(x):
    """ grad_for """
    return C.grad_all(_for)(x)


def test_grad_for():
    """ test_grad_for """
    assert grad_for(5) == (60,)


@ms_function
def try_tail(x):
    """ try_tail """
    return C.tail(x)


@non_graph_engine
def test_tail():
    """ test_tail """
    try_tail((0, 1, 2, 3))


@ms_function
def zero_like_tensor(x):
    """ zero_like_tensor """
    return C.zeros_like(x)


def test_zeros():
    """ test_zeros """
    x = Tensor(np.ones([2, 3]).astype(np.int32))
    res = zero_like_tensor(x)
    assert res == Tensor(np.zeros([2, 3]).astype(np.int32))


def test_ScalarGradChecker():
    """ test_ScalarGradChecker """
    def scalar_f(x, y):
        return x * y

    check_gradient(scalar_f, 1.0, 4.0, grad_checker_class=ScalarGradChecker, sampling_times=1)


def test_GradCheckerPrimitive():
    """ test_GradCheckerPrimitive """
    matmul = P.MatMul()

    def prim_f(x, y):
        return matmul(x, y)

    check_gradient(prim_f, Tensor(np.array([[0.65, 0.8, 0.8]], np.float32)),
                   Tensor(np.array([[0.1], [0.2], [-.1]], np.float32)),
                   grad_checker_class=OperationGradChecker, sampling_times=2)


def test_NNGradChecker():
    """ test_NNGradChecker """
    class Net(nn.Cell):
        """ Net definition """
        def __init__(self):
            super(Net, self).__init__()
            self.dense = nn.Dense(10, 10)

        def construct(self, x):
            out = self.dense(x)
            return out

    check_gradient(Net(), Tensor(np.random.rand(1, 10).astype(np.float32)),
                   delta=1e-3,
                   max_error=1e-3,
                   grad_checker_class=NNGradChecker, sampling_times=3)


def test_OperationGradChecker():
    """ test_OperationGradChecker """
    class Net(nn.Cell):
        """ Net definition """
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out

    check_gradient(Net(), Tensor(np.array([[0.65, 0.8, 0.8]], np.float32)),
                   Tensor(np.array([[0.1], [0.2], [-.1]], np.float32)),
                   grad_checker_class=OperationGradChecker,
                   input_selector=[1], sampling_times=2)


def test_ScalarJacobianChecker():
    """ test_ScalarJacobianChecker """
    def scalar_f(x, y):
        return x * y

    check_jacobian(scalar_f, 1.0, 4.0, grad_checker_class=ScalarGradChecker, input_selector=[0])


def test_OperationJacobianChecker():
    """ test_OperationJacobianChecker """
    class Net(nn.Cell):
        """ Net definition """
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return x, out

    check_jacobian(Net(), Tensor(np.array([[0.65, 0.8, 0.8], [0.1, 0.2, 0.3]], np.float32)),
                   Tensor(np.array([[0.1, 0.3], [0.2, 0.2], [-.1, 0.4]], np.float32)),
                   grad_checker_class=OperationGradChecker, input_selector=[0],
                   output_selector=[0])


def test_NNJacobianChecker():
    """ test_NNJacobianChecker """
    class Net(nn.Cell):
        """ Net definition """
        def __init__(self):
            super(Net, self).__init__()
            self.dense = nn.Dense(10, 10)

        def construct(self, x):
            out = self.dense(x)
            return out, x

    check_jacobian(Net(), Tensor(np.random.rand(1, 10).astype(np.float32)),
                   delta=1e-3,
                   max_error=1e-7,
                   grad_checker_class=NNGradChecker,
                   input_selector=[1],
                   output_selector=[0])


def multi_outputs(x, y):
    z = x + y
    return 2 * z, 2 * z


def test_grad_multi_outputs():
    assert C.grad_all_with_sens(multi_outputs)(2, 3, (1, 1)) == (4, 4)
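    # With sens (1, 1), the gradients from both outputs accumulate: each output
    # contributes d(2 * z)/dx = 2 per input, so each input gradient is 2 + 2 = 4.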


@ms_function
def while_sp(x, y, z):
    out = x
    i = c3
    while i < c2:
        out = mul(x, out)
        i = i + c3
    return out


def test_while_sp():
    y = Tensor(np.ones([1, 3]).astype(np.float32))
    z = Tensor(np.ones([1, 3]).astype(np.float32))
    x = Tensor(np.ones([1, 3]).astype(np.float32) * 2.0)
    res = while_sp(x, y, z)
    assert res == Tensor(np.ones([1, 3]).astype(np.float32) * 1024.0)
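    # The loop runs for i = 1..9, multiplying by x nine times: out = x**10 = 2**10 = 1024.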


def grad_refactor_simple_1(x, y):
    """ grad_refactor_simple_1 """
    return x * x + 2 * y


def test_grad_refactor_simple_1():
    assert C.grad_all(grad_refactor_simple_1)(2, 1) == (4, 2)


def grad_refactor_simple_2(x, y, z):
    """ grad_refactor_simple_2 """
    return x * y + z + x * y * z + x + x * y


def test_grad_refactor_simple_2():
    assert C.grad_all(grad_refactor_simple_2)(2, 3, 0) == (7, 4, 7)


def grad_refactor_1(a, b):
    """ grad_refactor_1 """
    def inner(x, y):
        return x * y

    return inner(a, b)


def test_grad_refactor_1():
    assert C.grad_all(grad_refactor_1)(2, 3) == (3, 2)


def grad_refactor_2(a, b):
    """ grad_refactor_2 """
    def inner(x):
        return x * b

    return inner(b) * inner(a)


def test_grad_refactor_2():
    assert C.grad_all(grad_refactor_2)(2, 3) == (27, 54)
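    # inner(b) * inner(a) = (b * b) * (a * b) = a * b**3, so at (2, 3) the
    # gradients are (b**3, 3 * a * b**2) = (27, 54).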


def grad_refactor_3(a):
    """ if_test """
    if a > 3:
        return 0
    return 3 * a


def test_grad_refactor_3():
    assert C.grad_all(grad_refactor_3)(3) == (3,)


def grad_refactor_4(a):
    """ if_test """
    if a > 3:
        return 3 * a
    return 0


def test_grad_refactor_4():
    assert C.grad_all(grad_refactor_4)(4) == (3,)


def grad_refactor_5(a):
    """ if_test """
    if a > 3:
        return 1
    return a


def test_grad_refactor_5():
    assert C.grad_all(grad_refactor_5)(1) == (1,)


def grad_refactor_6(a, b):
    """ if_test """
    if a > b:
        return 3 * a + b
    return 2 * b * a


def test_grad_refactor_6():
    assert C.grad_all(grad_refactor_6)(3, 2) == (3, 1)


def grad_refactor_while(x):
    """ grad_refactor_while """
    rval = x
    while rval < 4:
        rval = rval * rval
    return rval


def test_grad_refactor_9():
    assert C.grad_all(grad_refactor_while)(3) == (6,)
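    # At x = 3 the loop body runs once (3 < 4, then 9 >= 4), so the function
    # reduces to x * x and the gradient is 2 * x = 6.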


def grad_refactor__while_1(x):
    """ _while """
    ret = x * x
    i = 2
    while i <= 3:
        ret = ret * i
        i = i + 1
    return ret


def test_grad_refactor_10():
    """ test_grad_while """
    assert C.grad_all(grad_refactor__while_1)(5) == (60,)


def test_grad_refactor_11():
    class Net(nn.Cell):
        """ Net definition """
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y):
            return x * y * y

    net = Net()
    C.grad_all(net)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.ones([2]).astype(np.float32)))


def test_grad_refactor_12():
    class Net(nn.Cell):
        """ Net definition """
        def __init__(self):
            super(Net, self).__init__()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            return x * self.z * y

    net = Net()
    C.grad_all(net)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.zeros([2]).astype(np.float32)))


def test_grad_refactor_13():
    class Net(nn.Cell):
        """ Net definition """
        def __init__(self):
            super(Net, self).__init__()
            self.z = Parameter(Tensor(np.ones([2]).astype(np.float32)), name='z')

        def construct(self, x, y):
            return x * self.z * y

    net = Net()
    weights = ParameterTuple(net.trainable_params())
    C.grad_by_list(net, weights)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.zeros([2]).astype(np.float32)))


def grad_refactor_14(a, b):
    """ if_test """
    def inner1(x):
        return x * b

    def inner2(x):
        return a * b

    def inner3(x):
        if x > 2:
            return a
        return b

    return inner1(b) + inner2(a) + inner3(a)


def test_grad_refactor_14():
    assert C.grad_all(grad_refactor_14)(2, 3) == (3, 9)
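    # The sum is b * b + a * b + b (a = 2 is not > 2, so inner3 returns b);
    # gradients: d/da = b = 3, d/db = 2 * b + a + 1 = 9.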


class IfDeferInline(nn.Cell):
    def __init__(self, mul_size):
        super().__init__()
        self.mul_weight = Tensor(np.full(mul_size, 0.6, dtype=np.float32))
        self.mul = P.Mul()

    def construct(self, inputs):
        x = self.mul(inputs, self.mul_weight)
        if True:
            x = x
        return x


def test_grad_if_defer_inline():
    """ test_grad_if_defer_inline """
    network = IfDeferInline([128, 96])
    network.add_flags(defer_inline=False)
    inp = Tensor(np.ones([128, 96]).astype(np.float32))
    grads = C.grad_all(network)(inp)
    assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)


def test_bprop_with_wrong_output_num():
    class BpropWithWrongOutputNum(PrimitiveWithInfer):
        @prim_attr_register
        def __init__(self):
            super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')

        def __call__(self, x, y):
            return x

        def infer_shape(self, x_shape, y_shape):
            return x_shape

        def infer_dtype(self, x_type, y_type):
            return x_type

    @bprop_getters.register(BpropWithWrongOutputNum)
    def get_bprop_with_wrong_output_num(self):
        """Generate bprop for BpropWithWrongOutputNum"""
        def bprop(x, y, out, dout):
            # Deliberately wrong: the primitive has two inputs but only one
            # gradient is returned.
            return (dout,)

        return bprop

    class BpropWithWrongOutputNumCell(nn.Cell):
        def __init__(self):
            super(BpropWithWrongOutputNumCell, self).__init__()

        def construct(self, x, y):
            return BpropWithWrongOutputNum()(x, y)

    with pytest.raises(TypeError):
        C.grad_all(BpropWithWrongOutputNumCell())(1, 2)


def test_bprop_with_wrong_output_type():
    class BpropWithWrongOutputType(PrimitiveWithInfer):
        @prim_attr_register
        def __init__(self):
            super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')

        def __call__(self, x):
            return x

        def infer_shape(self, x_shape):
            return x_shape

        def infer_dtype(self, x_type):
            return x_type

    @bprop_getters.register(BpropWithWrongOutputType)
    def get_bprop_with_wrong_output_type(self):
        """Generate bprop for BpropWithWrongOutputType"""
        def bprop(x, out, dout):
            # Deliberately wrong: returns a scalar instead of a tensor gradient.
            return (1,)

        return bprop

    class BpropWithWrongOutputTypeCell(nn.Cell):
        def __init__(self):
            super(BpropWithWrongOutputTypeCell, self).__init__()

        def construct(self, x):
            return BpropWithWrongOutputType()(x)

    with pytest.raises(TypeError):
        C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))


def test_bprop_with_wrong_output_shape():
    class BpropWithWrongOutputShape(PrimitiveWithInfer):
        @prim_attr_register
        def __init__(self):
            super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')

        def __call__(self, x):
            return x

        def infer_shape(self, x_shape):
            return x_shape

        def infer_dtype(self, x_type):
            return x_type

    @bprop_getters.register(BpropWithWrongOutputShape)
    def get_bprop_with_wrong_output_shape(self):
        """Generate bprop for BpropWithWrongOutputShape"""
        ones = Tensor(np.ones([2]).astype(np.int32))

        def bprop(x, out, dout):
            # Deliberately wrong: the gradient shape [2] does not match the
            # input shape [64, 10].
            return (ones,)

        return bprop

    class BpropWithWrongOutputShapeCell(nn.Cell):
        def __init__(self):
            super(BpropWithWrongOutputShapeCell, self).__init__()

        def construct(self, x):
            return BpropWithWrongOutputShape()(x)

    with pytest.raises(TypeError):
        C.grad_all(BpropWithWrongOutputShapeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))