# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """ test_framstruct """
  16. import pytest
  17. import numpy as np
  18. import mindspore.nn as nn
  19. from mindspore import context
  20. from mindspore.ops import composite as C
  21. from mindspore.ops import operations as P
  22. from mindspore.common.tensor import Tensor
  23. from mindspore.common.parameter import Parameter, ParameterTuple
  24. from mindspore.common import dtype as mstype
  25. from ..ut_filter import non_graph_engine
  26. from ....mindspore_test_framework.utils.check_gradient import (
  27. ms_function, check_jacobian, Tensor, NNGradChecker,
  28. OperationGradChecker, check_gradient, ScalarGradChecker)
  29. from mindspore.ops._grad.grad_base import bprop_getters
  30. from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
  31. def setup_module(module):
  32. context.set_context(mode=context.PYNATIVE_MODE)
@ms_function
def while_upper_bound(upper):
    rval = 2
    while rval < upper:
        rval = rval * rval
    return rval


def test_while_upper_bound():
    res = while_upper_bound(10)
    assert res == 16


@ms_function
def while_lower_bound(lower):
    """ t_while """
    rval = lower
    while rval < 100:
        rval = rval * rval
    return rval


def test_while_lower_bound():
    res = while_lower_bound(2)
    assert res == 256


@ms_function
def dynamic_make_tuple(x, lower, upper):
    out = ()
    i = lower
    while i < upper:
        out = out + (x,)
        i = i + 1
    return out


def test_dynamic_make_tuple():
    # Dynamically and recursively creating a static type is invalid in
    # MindSpore, as MindSpore graphs are statically typed.
    with pytest.raises(RuntimeError):
        dynamic_make_tuple(2, 1, 5)

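# The loop bounds above arrive as runtime arguments, so the tuple's length
# cannot be inferred at compile time. The test below is the static
# counterpart: range(3) is a compile-time constant, so the loop can unroll
# during graph construction and the tuple type is fixed.
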
def test_make_tuple():
    # Statically and recursively creating a static type is valid in MindSpore.
    @ms_function
    def make_tuple(x):
        out = ()
        for i in range(3):
            out = out + (x,)
        return out

    res = make_tuple(5)
    assert res == (5, 5, 5)


@ms_function
def add(x, y):
    """ add """
    return x + y


def mul(x, y):
    """ mul """
    return x * y


def add_mul(x, y):
    """ add_mul """
    return (x + y) * y


def mainf(x, y):
    """ mainf """
    return C.grad_all(mul)(x, y)


def grad_add_mul(x, y):
    """ grad_add_mul """
    return C.grad_all(add_mul)(x, y)


@ms_function
def sub(x, y):
    """ sub """
    return x - y


@ms_function
def if_always_true(x):
    """ if_always_true """
    if True:
        return x
    else:
        return 0


def test_add():
    """ test_add """
    res = add(2.5, 3)
    assert res == 5.5


def test_sub():
    """ test_sub """
    res = sub(3.5, 3)
    assert res == 0.5


@non_graph_engine
def test_if_always_true():
    """ test_if_always_true """
    res = if_always_true(1)
    assert res == 1


@non_graph_engine
def test_f():
    """ test_f """
    res = mainf(3, 2)
    assert res == (2, 3)


@non_graph_engine
def test_grad_add_mul():
    """ test_grad_add_mul """
    res = grad_add_mul(3, 2)
    assert res == (2, 7)

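# Worked check for test_grad_add_mul: f(x, y) = (x + y) * y, so df/dx = y and
# df/dy = x + 2 * y. At (3, 2) that is (2, 3 + 4) = (2, 7), matching the
# asserted tuple.
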
def f(x):
    if x > 0:
        return f(x - 1)
    return x


@ms_function
def list_subscript():
    """ list_subscript """
    x = [1, 2, 3]
    return x[0] * x[1]


def test_list_subscript():
    """ test_list_subscript """
    res = list_subscript()
    assert res == 2


@ms_function
def ms_infer_for(xs, y):
    """ ms_infer_for """
    rval = y
    for x in xs:
        rval = rval + x
    return rval


def test_infer_for():
    """ test_infer_for """
    t = (1, 2, 3)
    y = 4
    res = ms_infer_for(t, y)
    assert res == 10


@ms_function
def if_construct(a, b):
    z = a
    if a > b:
        z = a + b
    else:
        z = a * b
    if z > b:
        return z - a
    else:
        return a - b


def test_if_construct():
    """ test_if_construct """
    res = if_construct(3, 6)
    assert res == 15


@ms_function
def if_scalar(a, b):
    """ if_scalar """
    if a:
        return a
    return b


def test_if_scalar1():
    """ test_if_scalar1 """
    res = if_scalar(3, 6)
    assert res == 3


def test_if_scalar2():
    """ test_if_scalar2 """
    res = if_scalar(0, 6)
    assert res == 6


@ms_function
def if_tensor(a, b):
    c = a
    if a < b:
        c = a + a
        if c < b:
            c = a + c
        else:
            c = a + b
    else:
        c = b + b
    out = c + c
    return out


def test_if_tensor():
    res = if_tensor(Tensor(np.ones([64, 10]).astype(np.int32)), Tensor(np.ones([64, 10]).astype(np.int32)))
    assert res == Tensor(np.ones([64, 10]).astype(np.int32) * 4)


@ms_function
def rec(x):
    """ rec """
    if x > 0:
        return rec(x - 1)
    return x


def test_grad_rec():
    """ test_grad_rec """
    res = C.grad(rec)(10)
    assert res == 1


def test_me_rec():
    """ test_me_rec """
    res = rec(10)
    assert res == 0


@ms_function
def t2_while(x, y):
    out = y - x
    i = 0
    while i < 10:
        out = mul(x, y)
        i = i + 1
    return out


def test_while2():
    res = t2_while(2, 3)
    assert res == 6


def test_grad_while2():
    res = C.grad(t2_while)(2, 3)
    assert res == 3

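# Worked check for test_grad_while2: every iteration overwrites out with
# mul(x, y), so after the loop out == x * y and d(out)/dx == y. At (2, 3) the
# gradient with respect to x is therefore 3.
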
def if_test(a, b):
    """ if_test """
    if a > b:
        return 3 * a
    return 2 * b


def grad_if(x, y):
    """ grad_if """
    return C.grad_all(if_test)(x, y)


def test_grad_if():
    """ test_grad_if """
    assert grad_if(5, 4) == (3, 0)

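# Worked check for test_grad_if: at (5, 4) the branch a > b is taken, so the
# function reduces to 3 * a. The gradient is 3 with respect to a and 0 with
# respect to b, i.e. (3, 0).
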
# While loops are not unrolled in the forward and backward graphs.
def test_dont_unroll_while():
    def dont_unroll_while(x, y):
        i = 2
        out = y - x
        while i < 10:
            out = mul(x, y)
            i = i + 1
        return out

    @ms_function()
    def invoke_while(x, y):
        return C.grad(dont_unroll_while)(x, y)

    res = invoke_while(2, 3)
    assert res == 3


class ConvNet(nn.Cell):
    def __init__(self):
        super(ConvNet, self).__init__()
        out_channel = 16
        kernel_size = 3
        self.conv = P.Conv2D(out_channel,
                             kernel_size,
                             mode=1,
                             pad_mode="pad",
                             pad=0,
                             stride=1,
                             dilation=2,
                             group=1)
        self.w = Parameter(Tensor(np.ones([16, 16, 3, 3]).astype(np.float32)), name='w')

    def construct(self, x):
        return self.conv(x, self.w)


conv = ConvNet()
c1 = Tensor([2], mstype.float32)
c2 = Tensor([10], mstype.float32)
c3 = Tensor([1], mstype.float32)


@ms_function
def t1_while(x, y, z):
    out = x
    i = c1
    while i < c2:
        out = out + conv(z)
        i = i + c3
    out = out + out
    return out


def test_while_net():
    y = Tensor(np.ones([1, 3, 3, 4]).astype(np.float32))
    x = Tensor(np.ones([1, 16, 12, 12]).astype(np.float32))
    z = Tensor(np.ones([1, 16, 16, 16]).astype(np.float32))
    res = t1_while(x, y, z)
    assert res == Tensor(np.ones([1, 16, 12, 12]).astype(np.float32) * 2306.0)

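# Worked check for test_while_net (arithmetic sketch, assuming every output
# position of the convolution is fully covered): conv applies an all-ones
# 16x16x3x3 kernel (dilation 2, no padding) to an all-ones 1x16x16x16 input,
# so each element of the 1x16x12x12 output sums 16 * 3 * 3 = 144. The loop
# runs for i = 2..9 (8 iterations), giving 1 + 8 * 144 = 1153 per element,
# and the final out + out doubles that to 2306.
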
@ms_function
def if_while(a, b, x, z):
    c = a
    i = c1
    out = x
    if a < b:
        c = a + a
        while i < c2:
            out = out + conv(z)
            i = i + c3
    else:
        c = b + b
    out = c + c
    return out


def test_if_while():
    x = Tensor(np.random.randn(1, 16, 12, 12).astype(np.float32))
    z = Tensor(np.random.randn(1, 16, 16, 16).astype(np.float32))
    res = if_while(Tensor(np.ones([64, 10]).astype(np.float32)), Tensor(np.ones([64, 10]).astype(np.float32)), x, z)
    assert res == Tensor(np.ones([64, 10]).astype(np.float32) * 4.0)

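# Worked check for test_if_while: a and b are identical all-ones tensors, so
# a < b is false and the else branch runs: c = b + b, then out = c + c, i.e.
# 4.0 everywhere, regardless of the random x and z inputs.
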
def _while(x):
    """ _while """
    ret = x * x
    i = 2
    while i <= 3:
        ret = ret * i
        i = i + 1
    return ret


def grad_while(x):
    """ grad_while """
    return C.grad_all(_while)(x)


def test_grad_while():
    """ test_grad_while """
    assert grad_while(5) == (60,)


@ms_function
def factorial(n):
    """ factorial """
    if n == 0:
        return 1
    return n * factorial(n - 1)


def test_factorial():
    res = factorial(3)
    assert res == 6


def test_grad_factorial():
    res = C.grad(factorial)(3)
    assert res == 11

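# Worked check for test_grad_factorial: differentiating f(n) = n * f(n - 1)
# with f(0) = 1 gives the recursion f'(n) = f(n - 1) + n * f'(n - 1) with
# f'(0) = 0, so f'(1) = 1, f'(2) = 3, and f'(3) = 2 + 3 * 3 = 11. A plain
# Python sketch of that recursion (illustrative helper, not part of the
# original suite):
def _factorial_grad_reference(n):
    """Return (f(n), f'(n)) for f(n) = n * f(n - 1), f(0) = 1."""
    if n == 0:
        return 1, 0
    val, grad = _factorial_grad_reference(n - 1)
    return n * val, val + n * grad
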
@ms_function
def factorial2(n):
    """ factorial2 """
    if n != 0:
        return n * factorial2(n - 1)
    elif n == 1:
        return 1 * factorial2(n - 1)
    else:
        return 1


def test_factorial2():
    res = factorial2(3)
    assert res == 6


@ms_function
def foo(n):
    if n <= 1:
        if n == 1:
            return foo(n - 1)
        else:
            return 1
    else:
        return foo(n - 1)


def test_foo():
    res = foo(5)
    assert res == 1


@ms_function
def double_nested_loop(x):
    i = 0
    s = 0
    while i < x:
        j = 0
        i = i + 1
        while j < 3:
            j = j + 1
            s = s + j
    return s


def test_nested_loop():
    res = double_nested_loop(3)
    assert res == 18

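# Worked check for test_nested_loop: each outer iteration adds 1 + 2 + 3 = 6,
# and the outer while runs x = 3 times, so s == 18.
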
@ms_function
def double_nested_loop2(x):
    s = 0
    for i in range(x):
        for j in range(3):
            s = s + j
    return s


def test_nested_loop2():
    res = double_nested_loop2(1)
    assert res == 3


def _for(x):
    """ _for """
    ret = x * x
    for i in (2, 3):
        ret = ret * i
    return ret


def grad_for(x):
    """ grad_for """
    return C.grad_all(_for)(x)


def test_grad_for():
    """ test_grad_for """
    assert grad_for(5) == (60,)

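# Worked check for test_grad_for: _for(x) = x * x * 2 * 3 = 6 * x**2, so the
# derivative is 12 * x, which is 60 at x = 5.
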
@ms_function
def try_tail(x):
    """ try_tail """
    return C.tail(x)


@non_graph_engine
def test_tail():
    """ test_tail """
    try_tail((0, 1, 2, 3))


@ms_function
def zero_like_tensor(x):
    """ zero_like_tensor """
    return C.zeros_like(x)


def test_zeros():
    """ test_zeros """
    x = Tensor(np.ones([2, 3]).astype(np.int32))
    res = zero_like_tensor(x)
    assert res == Tensor(np.zeros([2, 3]).astype(np.int32))


def test_ScalarGradChecker():
    """ test_ScalarGradChecker """
    def scalar_f(x, y):
        return x * y

    check_gradient(scalar_f, 1.0, 4.0, grad_checker_class=ScalarGradChecker, sampling_times=1)


def test_GradCheckerPrimitive():
    """ test_GradCheckerPrimitive """
    matmul = P.MatMul()

    def prim_f(x, y):
        return matmul(x, y)

    check_gradient(prim_f, Tensor(np.array([[0.65, 0.8, 0.8]], np.float32)),
                   Tensor(np.array([[0.1], [0.2], [-.1]], np.float32)),
                   grad_checker_class=OperationGradChecker, sampling_times=2)

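# The checkers above validate analytic gradients against numeric estimates.
# A minimal sketch of the underlying idea for a scalar-to-scalar function
# (hypothetical helper, not the framework's API; relies only on plain Python):
def _numeric_grad(func, x, delta=1e-3):
    """Central-difference estimate of d(func)/dx at x (sketch only)."""
    return (func(x + delta) - func(x - delta)) / (2 * delta)

# For example, for scalar_f(x, y) = x * y at (1.0, 4.0), the derivative with
# respect to x is y = 4.0, and _numeric_grad(lambda v: v * 4.0, 1.0) returns
# approximately 4.0.
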
def test_NNGradChecker():
    """ test_NNGradChecker """
    class Net(nn.Cell):
        """ Net definition """
        def __init__(self):
            super(Net, self).__init__()
            self.dense = nn.Dense(10, 10)

        def construct(self, x):
            out = self.dense(x)
            return out

    check_gradient(Net(), Tensor(np.random.rand(1, 10).astype(np.float32)),
                   delta=1e-3,
                   max_error=1e-3,
                   grad_checker_class=NNGradChecker, sampling_times=3)


def test_OperationGradChecker():
    """ test_OperationGradChecker """
    class Net(nn.Cell):
        """ Net definition """
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out

    check_gradient(Net(), Tensor(np.array([[0.65, 0.8, 0.8]], np.float32)),
                   Tensor(np.array([[0.1], [0.2], [-.1]], np.float32)),
                   grad_checker_class=OperationGradChecker,
                   input_selector=[1], sampling_times=2)


def test_ScalarJacobianChecker():
    """ test_ScalarJacobianChecker """
    def scalar_f(x, y):
        return x * y

    check_jacobian(scalar_f, 1.0, 4.0, grad_checker_class=ScalarGradChecker, input_selector=[0])


def test_OperationJacobianChecker():
    """ test_OperationJacobianChecker """
    class Net(nn.Cell):
        """ Net definition """
        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return x, out

    check_jacobian(Net(), Tensor(np.array([[0.65, 0.8, 0.8], [0.1, 0.2, 0.3]], np.float32)),
                   Tensor(np.array([[0.1, 0.3], [0.2, 0.2], [-.1, 0.4]], np.float32)),
                   grad_checker_class=OperationGradChecker, input_selector=[0],
                   output_selector=[0])


def test_NNJacobianChecker():
    """ test_NNJacobianChecker """
    class Net(nn.Cell):
        """ Net definition """
        def __init__(self):
            super(Net, self).__init__()
            self.dense = nn.Dense(10, 10)

        def construct(self, x):
            out = self.dense(x)
            return out, x

    check_jacobian(Net(), Tensor(np.random.rand(1, 10).astype(np.float32)),
                   delta=1e-3,
                   max_error=1e-7,
                   grad_checker_class=NNGradChecker,
                   input_selector=[1],
                   output_selector=[0])


def multi_outputs(x, y):
    z = x + y
    return 2 * z, 2 * z


def test_grad_multi_outputs():
    assert C.grad_all_with_sens(multi_outputs)(2, 3, (1, 1)) == (4, 4)

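# Worked check for test_grad_multi_outputs: each output is 2 * (x + y), so
# each contributes a gradient of 2 per input. With sensitivities (1, 1) the
# contributions sum to 4 for x and 4 for y.
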
@ms_function
def while_sp(x, y, z):
    out = x
    i = c3
    while i < c2:
        out = mul(x, out)
        i = i + c3
    return out


def test_while_sp():
    y = Tensor(np.ones([1, 3]).astype(np.float32))
    z = Tensor(np.ones([1, 3]).astype(np.float32))
    x = Tensor(np.ones([1, 3]).astype(np.float32) * 2.0)
    res = while_sp(x, y, z)
    assert res == Tensor(np.ones([1, 3]).astype(np.float32) * 1024.0)

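# Worked check for test_while_sp: out starts at x and is multiplied by x on
# each of the 9 iterations (i = 1..9 with i < 10), so out == x**10. With
# x == 2.0 everywhere, that is 1024.0.
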
def grad_refactor_simple_1(x, y):
    """ grad_refactor_simple_1 """
    return x * x + 2 * y


def test_grad_refactor_simple_1():
    assert C.grad_all(grad_refactor_simple_1)(2, 1) == (4, 2)


def grad_refactor_simple_2(x, y, z):
    """ grad_refactor_simple_2 """
    return x * y + z + x * y * z + x + x * y


def test_grad_refactor_simple_2():
    assert C.grad_all(grad_refactor_simple_2)(2, 3, 0) == (7, 4, 7)


def grad_refactor_1(a, b):
    """ grad_refactor_1 """
    def inner(x, y):
        return x * y

    return inner(a, b)


def test_grad_refactor_1():
    assert C.grad_all(grad_refactor_1)(2, 3) == (3, 2)


def grad_refactor_2(a, b):
    """ grad_refactor_2 """
    def inner(x):
        return x * b

    return inner(b) * inner(a)


def test_grad_refactor_2():
    assert C.grad_all(grad_refactor_2)(2, 3) == (27, 54)

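# Worked check for test_grad_refactor_2: inner(b) * inner(a) = (b * b) * (a * b)
# = a * b**3. At (2, 3): d/da = b**3 = 27 and d/db = 3 * a * b**2 = 54.
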
def grad_refactor_3(a):
    """ grad_refactor_3 """
    if a > 3:
        return 0
    return 3 * a


def test_grad_refactor_3():
    assert C.grad_all(grad_refactor_3)(3) == (3,)


def grad_refactor_4(a):
    """ grad_refactor_4 """
    if a > 3:
        return 3 * a
    return 0


def test_grad_refactor_4():
    assert C.grad_all(grad_refactor_4)(4) == (3,)


def grad_refactor_5(a):
    """ grad_refactor_5 """
    if a > 3:
        return 1
    return a


def test_grad_refactor_5():
    assert C.grad_all(grad_refactor_5)(1) == (1,)


def grad_refactor_6(a, b):
    """ grad_refactor_6 """
    if a > b:
        return 3 * a + b
    return 2 * b * a


def test_grad_refactor_6():
    assert C.grad_all(grad_refactor_6)(3, 2) == (3, 1)


def grad_refactor_while(x):
    """ grad_refactor_while """
    rval = x
    while rval < 4:
        rval = rval * rval
    return rval


def test_grad_refactor_9():
    assert C.grad_all(grad_refactor_while)(3) == (6,)

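# Worked check for test_grad_refactor_9: with x = 3 the loop squares rval
# exactly once (3 -> 9 >= 4), so the executed path computes x**2 and the
# derivative is 2 * x = 6.
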
def grad_refactor__while_1(x):
    """ _while """
    ret = x * x
    i = 2
    while i <= 3:
        ret = ret * i
        i = i + 1
    return ret


def test_grad_refactor_10():
    """ test_grad_while """
    assert C.grad_all(grad_refactor__while_1)(5) == (60,)


def test_grad_refactor_11():
    class Net(nn.Cell):
        """ Net definition """
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y):
            return x * y * y

    net = Net()
    C.grad_all(net)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.ones([2]).astype(np.float32)))


def test_grad_refactor_12():
    class Net(nn.Cell):
        """ Net definition """
        def __init__(self):
            super(Net, self).__init__()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            return x * self.z * y

    net = Net()
    C.grad_all(net)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.zeros([2]).astype(np.float32)))


def test_grad_refactor_13():
    class Net(nn.Cell):
        """ Net definition """
        def __init__(self):
            super(Net, self).__init__()
            self.z = Parameter(Tensor(np.ones([2]).astype(np.float32)), name='z')

        def construct(self, x, y):
            return x * self.z * y

    net = Net()
    weights = ParameterTuple(net.trainable_params())
    C.grad_by_list(net, weights)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.zeros([2]).astype(np.float32)))


def grad_refactor_14(a, b):
    """ grad_refactor_14 """
    def inner1(x):
        return x * b

    def inner2(x):
        return a * b

    def inner3(x):
        if x > 2:
            return a
        return b

    return inner1(b) + inner2(a) + inner3(a)


def test_grad_refactor_14():
    assert C.grad_all(grad_refactor_14)(2, 3) == (3, 9)

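# Worked check for test_grad_refactor_14: at (2, 3), inner3 takes the x <= 2
# branch, so the sum is b * b + a * b + b. Then d/da = b = 3 and
# d/db = 2 * b + a + 1 = 9.
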
class IfDeferInline(nn.Cell):
    def __init__(self, mul_size):
        super().__init__()
        self.mul_weight = Tensor(np.full(mul_size, 0.6, dtype=np.float32))
        self.mul = P.Mul()

    def construct(self, inputs):
        x = self.mul(inputs, self.mul_weight)
        if True:
            x = x
        return x


def test_grad_if_defer_inline():
    """ test_grad_if_defer_inline """
    network = IfDeferInline([128, 96])
    network.add_flags(defer_inline=False)
    inp = Tensor(np.ones([128, 96]).astype(np.float32))
    grads = C.grad_all(network)(inp)
    assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)

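# Worked check for test_grad_if_defer_inline: the network is an elementwise
# multiply by a constant 0.6 tensor (the trivial if does not change x), so
# the gradient with respect to the input is 0.6 everywhere.
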
def test_bprop_with_wrong_output_num():
    class BpropWithWrongOutputNum(PrimitiveWithInfer):
        @prim_attr_register
        def __init__(self):
            super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')

        def __call__(self, x, y):
            return x

        def infer_shape(self, x_shape, y_shape):
            return x_shape

        def infer_dtype(self, x_type, y_type):
            return x_type

    @bprop_getters.register(BpropWithWrongOutputNum)
    def get_bprop_with_wrong_output_num(self):
        """Generate bprop for BpropWithWrongOutputNum"""
        def bprop(x, y, out, dout):
            return (dout,)

        return bprop

    class BpropWithWrongOutputNumCell(nn.Cell):
        def __init__(self):
            super(BpropWithWrongOutputNumCell, self).__init__()

        def construct(self, x, y):
            return BpropWithWrongOutputNum()(x, y)

    with pytest.raises(TypeError):
        C.grad_all(BpropWithWrongOutputNumCell())(1, 2)

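# A registered bprop must return one gradient per forward input. Here the
# primitive takes two inputs (x, y) but its bprop returns a single value, so
# taking the gradient raises a TypeError. The next two tests apply the same
# pattern to a bprop whose gradient has the wrong dtype and the wrong shape.
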
def test_bprop_with_wrong_output_type():
    class BpropWithWrongOutputType(PrimitiveWithInfer):
        @prim_attr_register
        def __init__(self):
            super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')

        def __call__(self, x):
            return x

        def infer_shape(self, x_shape):
            return x_shape

        def infer_dtype(self, x_type):
            return x_type

    @bprop_getters.register(BpropWithWrongOutputType)
    def get_bprop_with_wrong_output_type(self):
        """Generate bprop for BpropWithWrongOutputType"""
        def bprop(x, out, dout):
            return (1,)

        return bprop

    class BpropWithWrongOutputTypeCell(nn.Cell):
        def __init__(self):
            super(BpropWithWrongOutputTypeCell, self).__init__()

        def construct(self, x):
            return BpropWithWrongOutputType()(x)

    with pytest.raises(TypeError):
        C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))


def test_bprop_with_wrong_output_shape():
    class BpropWithWrongOutputShape(PrimitiveWithInfer):
        @prim_attr_register
        def __init__(self):
            super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')

        def __call__(self, x):
            return x

        def infer_shape(self, x_shape):
            return x_shape

        def infer_dtype(self, x_type):
            return x_type

    @bprop_getters.register(BpropWithWrongOutputShape)
    def get_bprop_with_wrong_output_shape(self):
        """Generate bprop for BpropWithWrongOutputShape"""
        ones = Tensor(np.ones([2]).astype(np.int32))

        def bprop(x, out, dout):
            return (ones,)

        return bprop

    class BpropWithWrongOutputShapeCell(nn.Cell):
        def __init__(self):
            super(BpropWithWrongOutputShapeCell, self).__init__()

        def construct(self, x):
            return BpropWithWrongOutputShape()(x)

    with pytest.raises(TypeError):
        C.grad_all(BpropWithWrongOutputShapeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))