You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_framstruct.py 23 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_framstruct """
  16. import numpy as np
  17. import pytest
  18. import mindspore as ms
  19. import mindspore.nn as nn
  20. from mindspore import context
  21. from mindspore.common import dtype as mstype
  22. from mindspore.common.parameter import Parameter, ParameterTuple
  23. from mindspore.common.tensor import Tensor
  24. from mindspore.ops import composite as C
  25. from mindspore.ops import operations as P
  26. from ..ut_filter import non_graph_engine
  27. from ....mindspore_test_framework.utils.check_gradient import (
  28. ms_function, check_jacobian, Tensor, NNGradChecker,
  29. OperationGradChecker, check_gradient, ScalarGradChecker)
# Applied at import time; setup_module below re-applies it before the tests run.
context.set_context(mode=context.PYNATIVE_MODE)


def setup_module(module):
    """pytest module-level hook: switch MindSpore to PyNative mode.

    Args:
        module: the test module object passed in by pytest (not used).
    """
    context.set_context(mode=context.PYNATIVE_MODE)


# Shared gradient transforms used throughout this file.
# grad: single gradient of the first input (e.g. grad(rec)(3) yields one value).
grad = C.GradOperation()
# grad_all: tuple of gradients, one per positional input.
grad_all = C.GradOperation(get_all=True)
# grad_by_list: gradients w.r.t. a ParameterTuple, called as grad_by_list(net, weights).
grad_by_list = C.GradOperation(get_by_list=True)
# grad_all_with_sens: like grad_all, but the caller passes the output sensitivity last.
grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)
@ms_function
def while_upper_bound(upper):
    """Square ``rval`` (starting at 2) until it is no longer below ``upper``; return it."""
    rval = 2
    while rval < upper:
        rval = rval * rval
    return rval


def test_while_upper_bound():
    """2 -> 4 -> 16: 16 is the first value not below 10."""
    res = while_upper_bound(10)
    assert res == 16
@ms_function
def while_lower_bound(lower):
    """Square ``rval`` (starting at ``lower``) until it is no longer below 100; return it."""
    rval = lower
    while rval < 100:
        rval = rval * rval
    return rval


def test_while_lower_bound():
    """2 -> 4 -> 16 -> 256: 256 is the first value not below 100."""
    res = while_lower_bound(2)
    assert res == 256
@ms_function
def dynamic_make_tuple(x, lower, upper):
    """Append ``x`` to a tuple once per iteration of a while loop.

    The tuple grows inside the loop, so its length (and thus its type) differs
    between iterations; the test below expects this to be rejected.
    """
    out = ()
    i = lower
    while i < upper:
        out = out + (x,)
        i = i + 1
    return out


def test_dynamic_make_tuple():
    # Dynamically (inside a while loop) growing a tuple changes its static type,
    # which MindSpore rejects because graph mode is statically typed.
    with pytest.raises(RuntimeError):
        dynamic_make_tuple(2, 1, 5)
def test_make_tuple():
    # Statically growing a tuple (fixed trip count via range(3)) is valid in MindSpore.
    @ms_function
    def make_tuple(x):
        out = ()
        # range(3) is a constant trip count, so the final tuple length is fixed.
        for i in range(3):
            out = out + (x,)
        return out
    res = make_tuple(5)
    assert res == (5, 5, 5)
@ms_function
def add(x, y):
    """Return ``x + y``, compiled with ms_function."""
    return x + y


def mul(x, y):
    """Return ``x * y`` (plain Python; traced when called from grad/ms_function code)."""
    return x * y


def add_mul(x, y):
    """Return ``(x + y) * y``; its partials are ``(y, x + 2 * y)``."""
    return (x + y) * y


def mainf(x, y):
    """Return the gradients of ``mul`` w.r.t. both inputs, i.e. ``(y, x)``."""
    return grad_all(mul)(x, y)


def grad_add_mul(x, y):
    """Return the gradients of ``add_mul`` w.r.t. both inputs."""
    return grad_all(add_mul)(x, y)


@ms_function
def sub(x, y):
    """Return ``x - y``, compiled with ms_function."""
    return x - y
# pylint: disable=using-constant-test
@ms_function
def if_always_true(x):
    """Return ``x`` through an ``if True`` branch; the ``else`` branch is never taken."""
    if True:
        return x
    else:
        return 0
  106. def test_add():
  107. """ test_add """
  108. res = add(2.5, 3)
  109. assert res == 5.5
  110. def test_sub():
  111. """ test_sub """
  112. res = sub(3.5, 3)
  113. assert res == 0.5
  114. @non_graph_engine
  115. def test_if_always_true():
  116. """ test_if_always_true """
  117. res = if_always_true(1)
  118. assert res == 1
  119. @non_graph_engine
  120. def test_f():
  121. """ test_f """
  122. res = mainf(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32))
  123. assert res == (2, 3)
  124. @non_graph_engine
  125. def test_grad_add_mul():
  126. """ test_grad_add_mul """
  127. res = grad_add_mul(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32))
  128. assert res == (2, 7)
  129. def f(x):
  130. if x > 0:
  131. return f(x - 1)
  132. return x
@ms_function
def list_subscript():
    """Index a constant list inside a compiled graph and return ``x[0] * x[1]``."""
    x = [1, 2, 3]
    return x[0] * x[1]


def test_list_subscript():
    """1 * 2 == 2."""
    res = list_subscript()
    assert res == 2
@ms_function
def ms_infer_for(xs, y):
    """Return ``y`` plus the sum of the elements of ``xs``, using a for loop."""
    rval = y
    for x in xs:
        rval = rval + x
    return rval


def test_infer_for():
    """4 + (1 + 2 + 3) == 10."""
    t = (1, 2, 3)
    y = 4
    res = ms_infer_for(t, y)
    assert res == 10
@ms_function
def if_construct(a, b):
    """Two sequential if/else statements on scalars.

    ``z`` becomes ``a + b`` when ``a > b`` and ``a * b`` otherwise; the result
    is ``z - a`` when ``z > b`` and ``a - b`` otherwise.
    """
    z = a
    if a > b:
        z = a + b
    else:
        z = a * b
    if z > b:
        return z - a
    else:
        return a - b


def test_if_construct():
    """(3, 6): 3 > 6 is false so z = 18; 18 > 6 so the result is 18 - 3 = 15."""
    res = if_construct(3, 6)
    assert res == 15
@ms_function
def if_scalar(a, b):
    """Use the scalar ``a`` directly as a condition: return ``a`` if truthy, else ``b``."""
    if a:
        return a
    return b


def test_if_scalar1():
    """A non-zero scalar condition takes the first branch."""
    res = if_scalar(3, 6)
    assert res == 3


def test_if_scalar2():
    """A zero scalar condition falls through to ``b``."""
    res = if_scalar(0, 6)
    assert res == 6
@ms_function
def if_tensor(a, b):
    """Nested if/else on tensor comparisons; returns ``c + c``.

    ``c`` is ``a + a + a`` or ``a + b`` (when ``a < b``, depending on whether
    ``a + a < b``), otherwise ``b + b``.
    """
    c = a
    if a < b:
        c = a + a
        if c < b:
            c = a + c
        else:
            c = a + b
    else:
        c = b + b
    out = c + c
    return out


def test_if_tensor():
    """a == b == 1, so the outer else runs: c = 2 and out = 4."""
    res = if_tensor(Tensor(np.ones([1]).astype(np.int32)), Tensor(np.ones([1]).astype(np.int32)))
    assert res == Tensor(np.ones([1]).astype(np.int32) * 4)
def rec(x):
    """Recursively subtract 1 from ``x`` while it is positive; return the first non-positive value."""
    if x > 0:
        return rec(x - 1)
    return x


@ms_function
def grad_rec(input_x):
    """Gradient of the recursive ``rec`` w.r.t. its input, compiled with ms_function."""
    return grad(rec)(input_x)


def test_grad_rec():
    """rec(x) is x minus a (piecewise) constant, so its gradient is 1."""
    res = grad_rec(3)
    assert res == 1


def test_me_rec():
    """rec(10) counts down to 0."""
    res = rec(10)
    assert res == 0
def t2_while(x, y):
    """Run a 10-step while loop that reassigns ``out = mul(x, y)`` on every step.

    The loop always runs, so the initial ``y - x`` is overwritten and the
    result is ``x * y``.
    """
    out = y - x
    i = 0
    while i < 10:
        out = mul(x, y)
        i = i + 1
    return out


def test_while2():
    """2 * 3 == 6."""
    res = t2_while(2, 3)
    assert res == 6


def test_grad_while2():
    """d(x * y)/dx at (2, 3) is y == 3."""
    @ms_function
    def df_t2_while(input_x, input_y):
        return grad(t2_while)(input_x, input_y)
    assert df_t2_while(2, 3) == 3
def if_test(a, b):
    """Return ``3 * a`` if ``a > b``, otherwise ``2 * b``."""
    if a > b:
        return 3 * a
    return 2 * b


def grad_if(x, y):
    """Return the gradients of ``if_test`` w.r.t. both inputs."""
    return grad_all(if_test)(x, y)


def test_grad_if():
    """5 > 4 selects 3 * a, whose partials are (3, 0)."""
    assert grad_if(Tensor(5, dtype=ms.int32), Tensor(4, dtype=ms.int32)) == (3, 0)
# While loop is not unrolled in forward and backward graphs.
def test_dont_unroll_while():
    """Gradient through a while loop whose body reassigns ``out = mul(x, y)``."""
    def dont_unroll_while(x, y):
        i = 2
        out = y - x
        while i < 10:
            out = mul(x, y)
            i = i + 1
        return out

    @ms_function()
    def invoke_while(x, y):
        return grad(dont_unroll_while)(x, y)
    res = invoke_while(2, 3)
    # d(x * y)/dx == y == 3.
    assert res == 3
class ConvNet(nn.Cell):
    """Single Conv2D (16 output channels, 3x3 kernel, dilation 2, no padding) with an all-ones weight."""

    def __init__(self):
        super(ConvNet, self).__init__()
        out_channel = 16
        kernel_size = 3
        # With dilation 2 and pad 0 the 3x3 kernel spans a 5x5 window, so a
        # 16x16 input produces a 12x12 output (stride 1).
        self.conv = P.Conv2D(out_channel,
                             kernel_size,
                             mode=1,
                             pad_mode="pad",
                             pad=0,
                             stride=1,
                             dilation=2,
                             group=1)
        # All-ones 16x16x3x3 float32 weight.
        self.w = Parameter(Tensor(np.ones([16, 16, 3, 3]).astype(np.float32)), name='w')

    def construct(self, x):
        return self.conv(x, self.w)


# Shared network and one-element float32 tensors used as loop start (c1),
# bound (c2) and step (c3) by the while-loop tests in this file.
conv = ConvNet()
c1 = Tensor([2], mstype.float32)
c2 = Tensor([10], mstype.float32)
c3 = Tensor([1], mstype.float32)
@ms_function
def t1_while(x, y, z):
    """Add ``conv(z)`` to ``x`` once per step for i = c1, c1 + c3, ... < c2, then double the sum.

    ``y`` is accepted but not used.
    """
    out = x
    i = c1
    while i < c2:
        out = out + conv(z)
        i = i + c3
    out = out + out
    return out


def test_while_net():
    """All-ones inputs: each conv output element is 16 * 3 * 3 = 144.

    The loop runs for i = 2, ..., 9 (8 steps): 1 + 8 * 144 = 1153, doubled = 2306.
    """
    y = Tensor(np.ones([1, 3, 3, 4]).astype(np.float32))
    x = Tensor(np.ones([1, 16, 12, 12]).astype(np.float32))
    z = Tensor(np.ones([1, 16, 16, 16]).astype(np.float32))
    res = t1_while(x, y, z)
    assert np.all(res.asnumpy() == np.ones([1, 16, 12, 12]).astype(np.float32) * 2306.0)
@ms_function
def if_while(a, b, x, z):
    """If/else where only the ``a < b`` branch contains a while loop.

    The trailing ``out = c + c`` overwrites whatever the loop accumulated, so
    the result is ``4 * a`` when ``a < b`` and ``4 * b`` otherwise.
    """
    c = a
    i = c1
    out = x
    if a < b:
        c = a + a
        while i < c2:
            out = out + conv(z)
            i = i + c3
    else:
        c = b + b
    out = c + c
    return out
  305. def test_if_while():
  306. x = Tensor(np.random.randn(1, 16, 12, 12).astype(np.float32))
  307. z = Tensor(np.random.randn(1, 16, 16, 16).astype(np.float32))
  308. res = if_while(Tensor(np.ones([1]).astype(np.float32)), Tensor(np.ones([1]).astype(np.float32)), x, z)
  309. assert np.all(res.asnumpy() == np.ones([64, 10]).astype(np.float32) * 4.0)
def _while(x):
    """Return ``x * x * 2 * 3``, multiplying by i = 2 and i = 3 in a while loop."""
    ret = x * x
    i = 2
    while i <= 3:
        ret = ret * i
        i = i + 1
    return ret


def grad_while(x):
    """Return the gradients of ``_while`` w.r.t. all inputs."""
    return grad_all(_while)(x)


def test_grad_while():
    """d(6 * x**2)/dx at x = 5 is 60."""
    assert grad_while(Tensor(5, dtype=ms.int32)) == (60,)
@ms_function
def factorial(n):
    """Recursive factorial: ``n * factorial(n - 1)`` with ``factorial(0) == 1``."""
    if n == 0:
        return 1
    return n * factorial(n - 1)


def test_factorial():
    """3! == 6."""
    res = factorial(3)
    assert res == 6


def test_grad_factorial():
    """Gradient flows through every recursive ``n - 1`` term as well.

    f'(n) = f(n - 1) + n * f'(n - 1) with f'(0) = 0 gives f'(1) = 1,
    f'(2) = 3 and f'(3) = 2 + 3 * 3 = 11.
    """
    @ms_function
    def df_factorial(x):
        return grad(factorial)(x)
    assert df_factorial(3) == 11
@ms_function
def factorial2(n):
    """Recursive factorial written as an if/elif/else chain.

    NOTE(review): the ``elif n == 1`` branch is unreachable because ``n != 0``
    already covers ``n == 1``; confirm whether it is intentional control-flow
    coverage for the compiler.
    """
    if n != 0:
        return n * factorial2(n - 1)
    elif n == 1:
        return 1 * factorial2(n - 1)
    else:
        return 1


def test_factorial2():
    """3! == 6."""
    res = factorial2(3)
    assert res == 6
@ms_function
def foo(n):
    """Nested if/else recursion that counts ``n`` down and returns 1 once ``n <= 0``."""
    if n <= 1:
        if n == 1:
            return foo(n - 1)
        else:
            return 1
    else:
        return foo(n - 1)


def test_foo():
    """5 -> 4 -> ... -> 0, which returns 1."""
    res = foo(5)
    assert res == 1
@ms_function
def double_nested_loop(x):
    """Nested while loops: each of the ``x`` outer steps adds 1 + 2 + 3 to ``s``."""
    i = 0
    s = 0
    while i < x:
        j = 0
        i = i + 1
        while j < 3:
            j = j + 1
            s = s + j
    return s


def test_nested_loop():
    """3 outer steps * 6 == 18."""
    res = double_nested_loop(3)
    assert res == 18
  376. @ms_function
  377. def double_nested_loop2(x):
  378. s = 0
  379. for i in range(x):
  380. for j in range(3):
  381. s = s + j
  382. return s
  383. def test_nested_loop2():
  384. res = double_nested_loop(1)
  385. assert res == 6
def _for(x):
    """Return ``x * x * 2 * 3``, multiplying by 2 and 3 in a for loop over a constant tuple."""
    ret = x * x
    for i in (2, 3):
        ret = ret * i
    return ret


@ms_function
def grad_for(x):
    """Gradients of ``_for`` w.r.t. all inputs, compiled with ms_function."""
    return grad_all(_for)(x)


def test_grad_for():
    """d(6 * x**2)/dx at x = 5 is 60."""
    assert grad_for(5) == (60,)
  399. @ms_function
  400. def try_tail(x):
  401. """ try_tail """
  402. return C.tail(x)
  403. @non_graph_engine
  404. def test_tail():
  405. """ test_tail """
  406. try_tail((0, 1, 2, 3))
@ms_function
def zero_like_tensor(x):
    """Return ``C.zeros_like(x)``, a zero tensor matching ``x``."""
    return C.zeros_like(x)


def test_zeros():
    """zeros_like of a 2x3 int32 tensor is a 2x3 int32 zero tensor."""
    x = Tensor(np.ones([2, 3]).astype(np.int32))
    res = zero_like_tensor(x)
    assert np.all(res.asnumpy() == np.zeros([2, 3]).astype(np.int32))
@ms_function
def arithmetic_simplify_01(x, y):
    """Return ``zeros_like(x) * y`` (identity: 0 * y == 0)."""
    return C.zeros_like(x) * y


def test_arithmetic_simplify_01():
    """zeros_like(x) * y is all zeros."""
    x = Tensor(np.ones([2, 3]).astype(np.int32))
    y = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
    res = arithmetic_simplify_01(x, y)
    expect = np.zeros([2, 3]).astype(np.int32)
    assert np.all(res.asnumpy() == expect)


@ms_function
def arithmetic_simplify_02(x, y):
    """Return ``ones_like(x) * y`` (identity: 1 * y == y)."""
    return C.ones_like(x) * y


def test_arithmetic_simplify_02():
    """ones_like(x) * y equals y."""
    x = Tensor(np.ones([2, 3]).astype(np.int32))
    y = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
    res = arithmetic_simplify_02(x, y)
    expect = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)
    assert np.all(res.asnumpy() == expect)


@ms_function
def arithmetic_simplify_03(x, y):
    """Return ``x * ones_like(y)`` (identity: x * 1 == x)."""
    return x * C.ones_like(y)


def test_arithmetic_simplify_03():
    """x * ones_like(y) equals x (all ones here)."""
    x = Tensor(np.ones([2, 3]).astype(np.int32))
    y = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
    res = arithmetic_simplify_03(x, y)
    expect = np.ones([2, 3]).astype(np.int32)
    assert np.all(res.asnumpy() == expect)


@ms_function
def arithmetic_simplify_04(x):
    """Return ``x + 0`` (identity: x + 0 == x)."""
    return x + 0


def test_arithmetic_simplify_04():
    """x + 0 equals x."""
    x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
    res = arithmetic_simplify_04(x)
    expect = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)
    assert np.all(res.asnumpy() == expect)


@ms_function
def arithmetic_simplify_05(x):
    """Return ``x * 1`` (identity: x * 1 == x)."""
    return x * 1


def test_arithmetic_simplify_05():
    """x * 1 equals x."""
    x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
    res = arithmetic_simplify_05(x)
    expect = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)
    assert np.all(res.asnumpy() == expect)


@ms_function
def arithmetic_simplify_06(x):
    """Return ``x * 2 * 5`` (constant factors combine to x * 10)."""
    return x * 2 * 5


def test_arithmetic_simplify_06():
    """x * 2 * 5 equals 10 * x."""
    x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
    res = arithmetic_simplify_06(x)
    expect = np.array([[10, 20, 30], [40, 50, 60]]).astype(np.int32)
    assert np.all(res.asnumpy() == expect)


@ms_function
def arithmetic_simplify_07(x):
    """Return ``(x + 1) * 2 * 5`` (equals 10 * (x + 1))."""
    return (x + 1) * 2 * 5


def test_arithmetic_simplify_07():
    """(x + 1) * 2 * 5 equals 10 * (x + 1)."""
    x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
    res = arithmetic_simplify_07(x)
    expect = np.array([[20, 30, 40], [50, 60, 70]]).astype(np.int32)
    assert np.all(res.asnumpy() == expect)


@ms_function
def arithmetic_simplify_08(x, y):
    """Return an expression padded with * 1 and + 0 terms that equals ``x + y``."""
    return 1 * x * 1 * 1 + 1 * 0 * 1 + 0 + y * 1


def test_arithmetic_simplify_08():
    """The padded expression equals x + y."""
    x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
    y = Tensor(np.ones([2, 3]).astype(np.int32))
    res = arithmetic_simplify_08(x, y)
    expect = np.array([[2, 3, 4], [5, 6, 7]]).astype(np.int32)
    assert np.all(res.asnumpy() == expect)
def test_ScalarGradChecker():
    """check_gradient on a plain Python scalar product x * y."""
    def scalar_f(x, y):
        return x * y
    check_gradient(scalar_f, 1.0, 4.0, grad_checker_class=ScalarGradChecker, sampling_times=1)


def test_GradCheckerPrimitive():
    """check_gradient on a MatMul primitive with a (1x3) and a (3x1) input."""
    matmul = P.MatMul()

    def prim_f(x, y):
        return matmul(x, y)
    check_gradient(prim_f, Tensor(np.array([[0.65, 0.8, 0.8]], np.float32)),
                   Tensor(np.array([[0.1], [0.2], [-.1]], np.float32)),
                   grad_checker_class=OperationGradChecker, sampling_times=2)


def test_NNGradChecker():
    """check_gradient on a single Dense(10, 10) cell with a random (1, 10) input."""
    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()
            self.dense = nn.Dense(10, 10)

        def construct(self, x):
            out = self.dense(x)
            return out
    check_gradient(Net(), Tensor(np.random.rand(1, 10).astype(np.float32)),
                   delta=1e-3,
                   max_error=1e-3,
                   grad_checker_class=NNGradChecker, sampling_times=3)


def test_OperationGradChecker():
    """check_gradient on a cell computing MatMul(x * z, y) with a parameter z."""
    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return out
    check_gradient(Net(), Tensor(np.array([[0.65, 0.8, 0.8]], np.float32)),
                   Tensor(np.array([[0.1], [0.2], [-.1]], np.float32)), grad_checker_class=OperationGradChecker,
                   input_selector=[1], sampling_times=2)


def test_ScalarJacobianChecker():
    """check_jacobian on a plain Python scalar product x * y."""
    def scalar_f(x, y):
        return x * y
    check_jacobian(scalar_f, 1.0, 4.0, grad_checker_class=ScalarGradChecker, input_selector=[0])


def test_OperationJacobianChecker():
    """check_jacobian on a two-output cell returning (x * z, MatMul(x * z, y))."""
    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()
            self.matmul = P.MatMul()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            x = x * self.z
            out = self.matmul(x, y)
            return x, out
    check_jacobian(Net(), Tensor(np.array([[0.65, 0.8, 0.8], [0.1, 0.2, 0.3]], np.float32)),
                   Tensor(np.array([[0.1, 0.3], [0.2, 0.2], [-.1, 0.4]], np.float32)),
                   grad_checker_class=OperationGradChecker, input_selector=[0],
                   output_selector=[0])


def test_NNJacobianChecker():
    """check_jacobian on a Dense(10, 10) cell that returns (dense(x), x)."""
    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()
            self.dense = nn.Dense(10, 10)

        def construct(self, x):
            out = self.dense(x)
            return out, x
    # NOTE(review): input_selector=[1] on a single-input cell and the much tighter
    # max_error (1e-7 vs 1e-3 in test_NNGradChecker) depend on check_jacobian's
    # selector semantics, which are not visible here; confirm they are intended.
    check_jacobian(Net(), Tensor(np.random.rand(1, 10).astype(np.float32)),
                   delta=1e-3,
                   max_error=1e-7,
                   grad_checker_class=NNGradChecker,
                   input_selector=[1],
                   output_selector=[0])
def multi_outputs(x, y):
    """Return the same value ``2 * (x + y)`` twice."""
    z = x + y
    return 2 * z, 2 * z


def test_grad_multi_outputs():
    """With sensitivity 1 on each output, each input's gradient is 2 + 2 = 4."""
    @ms_function
    def df_multi_outputs(x, y):
        return grad_all_with_sens(multi_outputs)(x, y, (1, 1))
    assert df_multi_outputs(2, 3) == (4, 4)
@ms_function
def while_sp(x, y, z):
    """Multiply ``out`` (starting at ``x``) by ``x`` for i = c3, 2 * c3, ... < c2.

    With c3 = 1 and c2 = 10 the loop runs 9 times, giving ``x ** 10``.
    ``y`` and ``z`` are accepted but not used.
    """
    out = x
    i = c3
    while i < c2:
        out = mul(x, out)
        i = i + c3
    return out


def test_while_sp():
    """2 ** 10 == 1024 elementwise."""
    y = Tensor(np.ones([1, 3]).astype(np.float32))
    z = Tensor(np.ones([1, 3]).astype(np.float32))
    x = Tensor(np.ones([1, 3]).astype(np.float32) * 2.0)
    res = while_sp(x, y, z)
    assert np.all(res.asnumpy() == np.ones([1, 3]).astype(np.float32) * 1024.0)
def grad_refactor_simple_1(x, y):
    """Return ``x * x + 2 * y``; partials are ``(2 * x, 2)``."""
    return x * x + 2 * y


def test_grad_refactor_simple_1():
    """Partials at (2, 1): (2 * 2, 2) == (4, 2)."""
    assert grad_all(grad_refactor_simple_1)(Tensor(2, dtype=ms.int32), Tensor(1, dtype=ms.int32)) == (4, 2)


def grad_refactor_simple_2(x, y, z):
    """Return ``x*y + z + x*y*z + x + x*y``."""
    return x * y + z + x * y * z + x + x * y


def test_grad_refactor_simple_2():
    """At (2, 3, 0): d/dx = 2y + yz + 1 = 7, d/dy = 2x + xz = 4, d/dz = 1 + xy = 7."""
    x = Tensor(2, dtype=ms.int32)
    y = Tensor(3, dtype=ms.int32)
    z = Tensor(0, dtype=ms.int32)
    assert grad_all(grad_refactor_simple_2)(x, y, z) == (7, 4, 7)
def grad_refactor_1(a, b):
    """Return ``a * b`` computed through a nested helper function."""
    def inner(x, y):
        return x * y
    return inner(a, b)


def test_grad_refactor_1():
    """Partials of a * b at (2, 3) are (b, a) == (3, 2)."""
    assert grad_all(grad_refactor_1)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (3, 2)


def grad_refactor_2(a, b):
    """Return ``inner(b) * inner(a)`` where ``inner`` closes over ``b``, i.e. ``a * b**3``."""
    def inner(x):
        return x * b
    return inner(b) * inner(a)


def test_grad_refactor_2():
    """At (2, 3): d/da = b**3 = 27, d/db = 3 * a * b**2 = 54."""
    assert grad_all(grad_refactor_2)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (27, 54)


def grad_refactor_3(a):
    """Return 0 if ``a > 3``, otherwise ``3 * a``."""
    if a > 3:
        return 0
    return 3 * a


def test_grad_refactor_3():
    """a = 3 takes the 3 * a branch, so the gradient is 3."""
    @ms_function
    def df_refactor_3(x):
        return grad_all(grad_refactor_3)(x)
    assert df_refactor_3(3) == (3,)


def grad_refactor_4(a):
    """Return ``3 * a`` if ``a > 3``, otherwise 0."""
    if a > 3:
        return 3 * a
    return 0


def test_grad_refactor_4():
    """a = 4 takes the 3 * a branch, so the gradient is 3."""
    assert grad_all(grad_refactor_4)(Tensor(4, dtype=ms.int32)) == (3,)


def grad_refactor_5(a):
    """Return 1 if ``a > 3``, otherwise ``a``."""
    if a > 3:
        return 1
    return a


def test_grad_refactor_5():
    """a = 1 returns a itself, so the gradient is 1."""
    @ms_function
    def df_refactor_5(x):
        return grad_all(grad_refactor_5)(x)
    assert df_refactor_5(1) == (1,)


def grad_refactor_6(a, b):
    """Return ``3 * a + b`` if ``a > b``, otherwise ``2 * b * a``."""
    if a > b:
        return 3 * a + b
    return 2 * b * a


def test_grad_refactor_6():
    """3 > 2 selects 3 * a + b, whose partials are (3, 1)."""
    assert grad_all(grad_refactor_6)(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32)) == (3, 1)
def grad_refactor_while(x):
    """Square ``rval`` (starting at ``x``) until it is no longer below 4."""
    rval = x
    while rval < 4:
        rval = rval * rval
    return rval


def test_grad_refactor_9():
    """x = 3 is squared once, so the gradient is 2 * x = 6."""
    @ms_function
    def df_refactor_while(input_x):
        return grad_all(grad_refactor_while)(input_x)
    assert df_refactor_while(3) == (6,)


def grad_refactor__while_1(x):
    """Return ``x * x * 2 * 3``, multiplying by i = 2 and i = 3 in a while loop."""
    ret = x * x
    i = 2
    while i <= 3:
        ret = ret * i
        i = i + 1
    return ret


def test_grad_refactor_10():
    """d(6 * x**2)/dx at x = 5 is 60."""
    assert grad_all(grad_refactor__while_1)(Tensor(5, dtype=ms.int32)) == (60,)
  684. def test_grad_refactor_11():
  685. class Net(nn.Cell):
  686. """ Net definition """
  687. def __init__(self):
  688. super(Net, self).__init__()
  689. def construct(self, x, y):
  690. return x * y * y
  691. net = Net()
  692. grad_all(net)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.ones([2]).astype(np.float32)))
  693. def test_grad_refactor_12():
  694. class Net(nn.Cell):
  695. """ Net definition """
  696. def __init__(self):
  697. super(Net, self).__init__()
  698. self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
  699. def construct(self, x, y):
  700. return x * self.z * y
  701. net = Net()
  702. grad_all(net)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.zeros([2]).astype(np.float32)))
  703. def test_grad_refactor_13():
  704. class Net(nn.Cell):
  705. """ Net definition """
  706. def __init__(self):
  707. super(Net, self).__init__()
  708. self.z = Parameter(Tensor(np.ones([2]).astype(np.float32)), name='z')
  709. def construct(self, x, y):
  710. return x * self.z * y
  711. net = Net()
  712. weights = ParameterTuple(net.trainable_params())
  713. grad_by_list(net, weights)(Tensor(np.ones([2]).astype(np.float32)), Tensor(np.zeros([2]).astype(np.float32)))
def grad_refactor_14(a, b):
    """Return ``b * b + a * b + (a if a > 2 else b)`` built from three nested helpers."""
    def inner1(x):
        return x * b

    # Ignores its argument and closes over both a and b.
    def inner2(x):
        return a * b

    def inner3(x):
        if x > 2:
            return a
        return b
    return inner1(b) + inner2(a) + inner3(a)


def test_grad_refactor_14():
    """At (2, 3), 2 > 2 is false, so f = b*b + a*b + b; partials (b, 2b + a + 1) = (3, 9)."""
    @ms_function
    def df_refactor_14(x, y):
        return grad_all(grad_refactor_14)(x, y)
    assert df_refactor_14(2, 3) == (3, 9)
# pylint: disable=using-constant-test
class IfDeferInline(nn.Cell):
    """Elementwise multiply by a constant 0.6 tensor, followed by a no-op ``if True`` branch."""

    def __init__(self, mul_size):
        super().__init__()
        # Plain Tensor (not a Parameter) of shape ``mul_size`` filled with 0.6.
        self.mul_weight = Tensor(np.full(mul_size, 0.6, dtype=np.float32))
        self.mul = P.Mul()

    def construct(self, inputs):
        x = self.mul(inputs, self.mul_weight)
        # Constant-true branch whose body is a self-assignment.
        if True:
            x = x
        return x


def test_grad_if_defer_inline():
    """Gradient of inputs * 0.6 w.r.t. inputs is 0.6 everywhere, with defer_inline disabled."""
    network = IfDeferInline([128, 96])
    network.add_flags(defer_inline=False)
    inp = Tensor(np.ones([128, 96]).astype(np.float32))
    grads = grad_all(network)(inp)
    assert np.all(grads[0].asnumpy() == np.full([128, 96], 0.6, dtype=np.float32))
def test_dict_const():
    """A Cell whose construct returns a constant dict attribute can be called without error."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.res = {'1': 10}

        def construct(self):
            return self.res
    Net()()