You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_auto_grad.py 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import mindspore.nn as nn
  17. import mindspore.ops as ops
  18. from mindspore import context
  19. from mindspore import Tensor
  20. from mindspore.ops import operations as P
  21. from mindspore.ops import composite as C
  22. from mindspore.common.parameter import Parameter, ParameterTuple
# Reusable gradient meta-operations shared by the tests below.
grad_all = C.GradOperation(get_all=True)  # gradients w.r.t. all network inputs
grad_by_list = C.GradOperation(get_by_list=True)  # gradients w.r.t. a ParameterTuple of weights
  25. class CropAndResizeNet(nn.Cell):
  26. def __init__(self, crop_size):
  27. super(CropAndResizeNet, self).__init__()
  28. self.crop_and_resize = P.CropAndResize()
  29. self.crop_size = crop_size
  30. def construct(self, x, boxes, box_indices):
  31. return self.crop_and_resize(x, boxes, box_indices, self.crop_size)
  32. def bprop(self, x, boxes, box_indices, out, dout):
  33. return x, boxes, box_indices
  34. class TestUserDefinedBpropNet(nn.Cell):
  35. def __init__(self, in_channel, out_channel):
  36. super(TestUserDefinedBpropNet, self).__init__()
  37. self.relu = nn.ReLU()
  38. self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=2, stride=1, has_bias=False,
  39. weight_init='ones', pad_mode='same')
  40. self.crop = CropAndResizeNet((10, 10))
  41. self.boxes = Tensor(np.ones((128, 4)).astype(np.float32))
  42. self.box_indices = Tensor(np.ones((128,)).astype(np.int32))
  43. def construct(self, x):
  44. x = self.relu(x)
  45. x = self.conv(x)
  46. x = self.crop(x, self.boxes, self.box_indices)
  47. return x
  48. class TestUserDefinedBpropGradNet(nn.Cell):
  49. def __init__(self, net):
  50. super(TestUserDefinedBpropGradNet, self).__init__()
  51. self.net = net
  52. def construct(self, x):
  53. return grad_all(self.net)(x)
  54. def test_user_defined_bprop():
  55. context.set_context(mode=context.GRAPH_MODE)
  56. net = TestUserDefinedBpropNet(3, 10)
  57. grad_net = TestUserDefinedBpropGradNet(net)
  58. x = Tensor(np.ones((128, 3, 12, 12)).astype(np.float32))
  59. grad_net(x)
  60. class TwoInputBPropOperator(nn.Cell):
  61. def __init__(self):
  62. super().__init__()
  63. self.op = P.Mul()
  64. self.add = P.Add()
  65. def construct(self, x, y):
  66. return self.op(x, y)
  67. def bprop(self, x, y, out, dout):
  68. return self.add(5, x), self.add(y, 9)
  69. class BPropOperatatorNet(nn.Cell):
  70. def __init__(self, mul_size):
  71. super().__init__()
  72. mul_np = np.full(mul_size, 0.1, dtype=np.float32)
  73. floordiv_np = np.full(mul_size, 0.1, dtype=np.float32)
  74. self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight")
  75. self.floordiv_weight = Parameter(Tensor(floordiv_np), name="floordiv_weight")
  76. self.mul = TwoInputBPropOperator()
  77. self.floor_div = P.FloorDiv()
  78. self.bn = nn.BatchNorm1d(num_features=96)
  79. def construct(self, inputs):
  80. x = self.mul(inputs, self.mul_weight)
  81. x = self.floor_div(x, self.floordiv_weight)
  82. x = self.bn(x)
  83. return x
  84. def test_user_defined_bprop_with_u():
  85. net = BPropOperatatorNet(mul_size=(128, 96))
  86. grad_net = TestUserDefinedBpropGradNet(net)
  87. x = Tensor(np.random.randn(128, 96).astype(np.float32))
  88. grad_net(x)
  89. class SinNet(nn.Cell):
  90. def __init__(self):
  91. super(SinNet, self).__init__()
  92. self.sin = ops.Sin()
  93. def construct(self, x):
  94. out = self.sin(x)
  95. return out
  96. class SinGrad(nn.Cell):
  97. def __init__(self, network):
  98. super(SinGrad, self).__init__()
  99. self.grad = ops.GradOperation()
  100. self.network = network
  101. def construct(self, x):
  102. gout = self.grad(self.network)(x)
  103. return gout
  104. class SinGradSec(nn.Cell):
  105. def __init__(self, network):
  106. super(SinGradSec, self).__init__()
  107. self.grad = ops.GradOperation()
  108. self.network = network
  109. def construct(self, x):
  110. gout = self.grad(self.network)(x)
  111. return gout
  112. def test_second_grad_with_j_primitive():
  113. context.set_context(mode=context.GRAPH_MODE)
  114. net = SinNet()
  115. first_grad = SinGrad(net)
  116. second_grad = SinGradSec(first_grad)
  117. x = Tensor(np.array([1.0], dtype=np.float32))
  118. second_grad(x)
  119. # A CNode being used as FV is MapMorphism after MapMorphism of call-site CNode;
  120. def test_ad_fv_cnode_order():
  121. context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
  122. class Net(nn.Cell):
  123. def __init__(self):
  124. super(Net, self).__init__()
  125. # cnode xay is not being MapMorphism when cnode second_level() is being MapMorphism and
  126. # BackPropagateFv as MapMorphism is started from output node and from left to right order.
  127. def construct(self, x, y):
  128. def first_level():
  129. xay = x + y
  130. def second_level():
  131. return xay
  132. return second_level() + xay
  133. return first_level()
  134. input_x = Tensor(np.array([1.0], dtype=np.float32))
  135. input_y = Tensor(np.array([2.0], dtype=np.float32))
  136. net = Net()
  137. net.add_flags_recursive(defer_inline=True)
  138. grad_net = grad_all(net)
  139. grad_net(input_x, input_y)
  140. # True and False branch of switch have different number of parameters.
  141. def test_if_branch_with_different_params():
  142. context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
  143. class Net(nn.Cell):
  144. def __init__(self):
  145. super(Net, self).__init__()
  146. self.weight1 = Parameter(Tensor(np.array([1.0], dtype=np.float32)), name="weight1")
  147. self.weight2 = Parameter(Tensor(np.array([2.0], dtype=np.float32)), name="weight2")
  148. def construct(self, idx, end, x):
  149. out = x
  150. if idx < end:
  151. out = out + self.weight1 * self.weight2
  152. else:
  153. out = out + self.weight1
  154. return out
  155. class GradNet(nn.Cell):
  156. def __init__(self, net):
  157. super(GradNet, self).__init__()
  158. self.net = net
  159. self.weights = ParameterTuple(net.trainable_params())
  160. def construct(self, idx, end, x):
  161. return grad_by_list(self.net, self.weights)(idx, end, x)
  162. idx = Tensor(np.array((0), dtype=np.int32))
  163. end = Tensor(np.array((3), dtype=np.int32))
  164. x = Tensor(np.array([2.0], dtype=np.float32))
  165. net = Net()
  166. grad_net = GradNet(net)
  167. grad_net(idx, end, x)
  168. # Only lift fv in scope of lift_top_func_graph other than all func_graphs inside manager.
  169. # Otherwise, "Illegal AnfNode for evaluating" may be reported
  170. # because weight1 in Net may use old_parameter other than replicated one.
  171. def test_limit_lift_fv_scope():
  172. context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
  173. class Net(nn.Cell):
  174. def __init__(self):
  175. super(Net, self).__init__()
  176. self.weight1 = Parameter(Tensor(np.array([1.0], dtype=np.float32)), name="weight1")
  177. def construct(self, x, y):
  178. def inner_add(a, b):
  179. return a + b
  180. out = inner_add(x, y) + self.weight1
  181. return out
  182. class GradNet(nn.Cell):
  183. def __init__(self, net):
  184. super(GradNet, self).__init__()
  185. self.net = net
  186. self.weights = ParameterTuple(net.trainable_params())
  187. def construct(self, x, y):
  188. def inner_grad_add(a, b):
  189. return a + b
  190. d_weight = grad_by_list(self.net, self.weights)(x, y)[0]
  191. d_out = inner_grad_add(d_weight, y)
  192. return d_out
  193. x = Tensor(np.array([2.0], dtype=np.float32))
  194. y = Tensor(np.array([2.0], dtype=np.float32))
  195. net = Net()
  196. net.add_flags_recursive(defer_inline=True)
  197. grad_net = GradNet(net)
  198. grad_net.add_flags_recursive(defer_inline=True)
  199. grad_net(x, y)
  200. def test_same_primal_used_by_multi_j():
  201. class Net(nn.Cell):
  202. def __init__(self):
  203. super(Net, self).__init__()
  204. def construct(self, x):
  205. return x
  206. class GradNet(nn.Cell):
  207. def __init__(self, net):
  208. super(GradNet, self).__init__()
  209. self.net = net
  210. self.grad = ops.GradOperation()
  211. def construct(self, x):
  212. out = self.net(x)
  213. gout = self.grad(self.net)(x)
  214. gout1 = self.grad(self.net)(x)
  215. return out, gout, gout1
  216. x = Tensor(np.array([1.0], dtype=np.float32))
  217. net = Net()
  218. grad = GradNet(net)
  219. grad(x)
def test_same_primal_used_by_multi_j_with_monad1():
    """Same side-effecting primal differentiated twice via two separate grad-fn calls.

    AdamNet mutates its parameters through P.Adam (side effect -> monad), and
    AdamGradNet applies `self.grad_fn(self.network)` at two distinct call
    sites — the structure under test (contrast with monad2, which hoists a
    single grad function).
    """
    class AdamNet(nn.Cell):
        def __init__(self, var, m, v):
            super(AdamNet, self).__init__()
            self.apply_adam = P.Adam()
            self.var = Parameter(var, name="var")
            self.m = Parameter(m, name="m")
            self.v = Parameter(v, name="v")

        def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
            # In-place update of var/m/v — this is the monad-producing op.
            self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
            return self.var

    class AdamGradNet(nn.Cell):
        def __init__(self, network):
            super(AdamGradNet, self).__init__()
            # sens_param=True: caller supplies the output sensitivity.
            self.grad_fn = ops.GradOperation(sens_param=True)
            self.sens = [Tensor(np.ones([3, 3, 3]).astype(np.float32)), Tensor(np.ones([3, 3, 3]).astype(np.float32))]
            self.network = network

        def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
            out = self.network(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
            # Two independent grad_fn(self.network) applications on the same primal.
            gout1 = self.grad_fn(self.network)(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[0])
            gout2 = self.grad_fn(self.network)(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[1])
            return out, gout1, gout2

    var = Tensor(np.ones([3, 3, 3]).astype(np.float32))
    m = Tensor(np.ones([3, 3, 3]).astype(np.float32))
    v = Tensor(np.ones([3, 3, 3]).astype(np.float32))
    beta1_power = Tensor(np.array([0.9], dtype=np.float32))
    beta2_power = Tensor(np.array([0.999], dtype=np.float32))
    lr = Tensor(np.array([0.001], dtype=np.float32))
    beta1 = Tensor(np.array([0.9], dtype=np.float32))
    beta2 = Tensor(np.array([0.999], dtype=np.float32))
    epsilon = Tensor(np.array([1e-8], dtype=np.float32))
    grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
    net = AdamNet(var, m, v)
    grad_net = AdamGradNet(net)
    grad_net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
def test_same_primal_used_by_multi_j_with_monad2():
    """Same side-effecting primal differentiated twice through ONE hoisted grad function.

    Identical to monad1 except that `self.grad(self.network)` is evaluated
    once into a local and then called twice — the structural variant under test.
    """
    class AdamNet(nn.Cell):
        def __init__(self, var, m, v):
            super(AdamNet, self).__init__()
            self.apply_adam = P.Adam()
            self.var = Parameter(var, name="var")
            self.m = Parameter(m, name="m")
            self.v = Parameter(v, name="v")

        def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
            # In-place update of var/m/v — this is the monad-producing op.
            self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
            return self.var

    class AdamGradNet(nn.Cell):
        def __init__(self, network):
            super(AdamGradNet, self).__init__()
            # sens_param=True: caller supplies the output sensitivity.
            self.grad = ops.GradOperation(sens_param=True)
            self.sens = [Tensor(np.ones([3, 3, 3]).astype(np.float32)), Tensor(np.ones([3, 3, 3]).astype(np.float32))]
            self.network = network

        def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
            out = self.network(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
            # Single grad function reused for both sensitivities.
            grad_fn = self.grad(self.network)
            gout1 = grad_fn(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[0])
            gout2 = grad_fn(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[1])
            return out, gout1, gout2

    var = Tensor(np.ones([3, 3, 3]).astype(np.float32))
    m = Tensor(np.ones([3, 3, 3]).astype(np.float32))
    v = Tensor(np.ones([3, 3, 3]).astype(np.float32))
    beta1_power = Tensor(np.array([0.9], dtype=np.float32))
    beta2_power = Tensor(np.array([0.999], dtype=np.float32))
    lr = Tensor(np.array([0.001], dtype=np.float32))
    beta1 = Tensor(np.array([0.9], dtype=np.float32))
    beta2 = Tensor(np.array([0.999], dtype=np.float32))
    epsilon = Tensor(np.array([1e-8], dtype=np.float32))
    grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
    net = AdamNet(var, m, v)
    grad_net = AdamGradNet(net)
    grad_net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
  291. def test_grad_args_type_error1():
  292. class Net(nn.Cell):
  293. def __init__(self):
  294. super(Net, self).__init__()
  295. self.matmul = P.MatMul()
  296. def construct(self, x, y):
  297. out = self.matmul(x, y)
  298. return out
  299. class GradNetWrtX(nn.Cell):
  300. def __init__(self, net):
  301. super(GradNetWrtX, self).__init__()
  302. self.net = net
  303. self.grad_op = ops.GradOperation(get_all=2)
  304. def construct(self, x, y):
  305. gradient_function = self.grad_op(self.net)
  306. return gradient_function(x, y)
  307. x = Tensor(np.array([2.0], dtype=np.float32))
  308. y = Tensor(np.array([2.0], dtype=np.float32))
  309. try:
  310. GradNetWrtX(Net())(x, y)
  311. except TypeError as e:
  312. assert "For 'GradOperation', the 'get_all' should be bool, but got" in str(e)
  313. def test_grad_args_type_error2():
  314. class Net(nn.Cell):
  315. def __init__(self):
  316. super(Net, self).__init__()
  317. self.matmul = P.MatMul()
  318. def construct(self, x, y):
  319. out = self.matmul(x, y)
  320. return out
  321. class GradNetWrtX(nn.Cell):
  322. def __init__(self, net):
  323. super(GradNetWrtX, self).__init__()
  324. self.net = net
  325. self.grad_op = ops.GradOperation(get_by_list=2)
  326. def construct(self, x, y):
  327. gradient_function = self.grad_op(self.net)
  328. return gradient_function(x, y)
  329. x = Tensor(np.array([2.0], dtype=np.float32))
  330. y = Tensor(np.array([2.0], dtype=np.float32))
  331. try:
  332. GradNetWrtX(Net())(x, y)
  333. except TypeError as e:
  334. assert "For 'GradOperation', the 'get_by_list' should be bool, but got" in str(e)
  335. def test_grad_args_type_error3():
  336. class Net(nn.Cell):
  337. def __init__(self):
  338. super(Net, self).__init__()
  339. self.matmul = P.MatMul()
  340. def construct(self, x, y):
  341. out = self.matmul(x, y)
  342. return out
  343. class GradNetWrtX(nn.Cell):
  344. def __init__(self, net):
  345. super(GradNetWrtX, self).__init__()
  346. self.net = net
  347. self.grad_op = ops.GradOperation(sens_param=2)
  348. def construct(self, x, y):
  349. gradient_function = self.grad_op(self.net)
  350. return gradient_function(x, y)
  351. x = Tensor(np.array([2.0], dtype=np.float32))
  352. y = Tensor(np.array([2.0], dtype=np.float32))
  353. try:
  354. GradNetWrtX(Net())(x, y)
  355. except TypeError as e:
  356. assert "For 'GradOperation', the 'sens_param' should be bool, but got" in str(e)
  357. def test_grad_net_is_none():
  358. class Net(nn.Cell):
  359. def __init__(self):
  360. super(Net, self).__init__()
  361. self.add = P.Add()
  362. def construct(self, x, y):
  363. out = self.add(x, y)
  364. return out
  365. class GradNetWrtX(nn.Cell):
  366. def __init__(self, net):
  367. super(GradNetWrtX, self).__init__()
  368. self.net = P.Add()
  369. self.grad_op = ops.GradOperation()
  370. def construct(self, x, y):
  371. gradient_function = self.grad_op(None)
  372. return gradient_function(x, y)
  373. x = Tensor(np.array([2.0], dtype=np.float32))
  374. y = Tensor(np.array([2.0], dtype=np.float32))
  375. try:
  376. GradNetWrtX(Net())(x, y)
  377. except Exception as e:
  378. assert "'GradOperation' arg0 must be a 'Function' or 'Cell', but got" in str(e)
  379. def test_grad_missing_net():
  380. class Net(nn.Cell):
  381. def __init__(self):
  382. super(Net, self).__init__()
  383. self.add = P.Add()
  384. def construct(self, x, y):
  385. out = self.add(x, y)
  386. return out
  387. class GradNetWrtX(nn.Cell):
  388. def __init__(self, net):
  389. super(GradNetWrtX, self).__init__()
  390. self.net = net
  391. self.grad_op = ops.GradOperation()
  392. def construct(self, x, y):
  393. gradient_function = self.grad_op()
  394. return gradient_function(x, y)
  395. x = Tensor(np.array([2.0], dtype=np.float32))
  396. y = Tensor(np.array([2.0], dtype=np.float32))
  397. try:
  398. GradNetWrtX(Net())(x, y)
  399. except Exception as e:
  400. assert "'GradOperation' requires a forward network or function as an input, while the input is empty." in str(e)
  401. def test_user_defined_bprop_inputs_size_error():
  402. class BpropUserDefinedNet(nn.Cell):
  403. def __init__(self):
  404. super(BpropUserDefinedNet, self).__init__()
  405. self.zeros_like = P.ZerosLike()
  406. def construct(self, x, y):
  407. return x + y
  408. def bprop(self, out):
  409. return self.zeros_like(out), self.zeros_like(out)
  410. class BpropUserDefinedGradNet(nn.Cell):
  411. def __init__(self, net):
  412. super(BpropUserDefinedGradNet, self).__init__()
  413. self.net = net
  414. def construct(self, x, y):
  415. return grad_all(self.net)(x, y)
  416. net = BpropUserDefinedNet()
  417. grad_net = BpropUserDefinedGradNet(net)
  418. x = Tensor(np.array([2.0], dtype=np.float32))
  419. y = Tensor(np.array([2.0], dtype=np.float32))
  420. try:
  421. grad_net(x, y)
  422. except Exception as e:
  423. assert "The function 'bprop' of Primitive or Cell requires at least 2 params 'out' and 'dout', but got only"\
  424. in str(e)
  425. def test_user_defined_bprop_net_has_parameter():
  426. class BpropUserDefinedNet(nn.Cell):
  427. def __init__(self):
  428. super(BpropUserDefinedNet, self).__init__()
  429. self.zeros_like = P.ZerosLike()
  430. self.x = Parameter(Tensor(np.array([2.0], dtype=np.float32)), name="x")
  431. def construct(self, y):
  432. return self.x + y
  433. def bprop(self, y, out, dout):
  434. return (self.zeros_like(out),)
  435. class BpropUserDefinedGradNet(nn.Cell):
  436. def __init__(self, net):
  437. super(BpropUserDefinedGradNet, self).__init__()
  438. self.net = net
  439. def construct(self, y):
  440. return grad_all(self.net)(y)
  441. net = BpropUserDefinedNet()
  442. grad_net = BpropUserDefinedGradNet(net)
  443. y = Tensor(np.array([2.0], dtype=np.float32))
  444. try:
  445. grad_net(y)
  446. except Exception as e:
  447. assert "The Cell with user defined 'bprop' function in scope" in str(e)
  448. assert "does not support Parameter data type." in str(e)
  449. def test_user_defined_bprop_inputs_size_error1():
  450. class BpropUserDefinedNet(nn.Cell):
  451. def __init__(self):
  452. super(BpropUserDefinedNet, self).__init__()
  453. self.zeros_like = P.ZerosLike()
  454. def construct(self, x, y):
  455. return x + y
  456. def bprop(self, x, y, out):
  457. return self.zeros_like(out), self.zeros_like(out)
  458. class BpropUserDefinedGradNet(nn.Cell):
  459. def __init__(self, net):
  460. super(BpropUserDefinedGradNet, self).__init__()
  461. self.net = net
  462. def construct(self, x, y):
  463. return grad_all(self.net)(x, y)
  464. net = BpropUserDefinedNet()
  465. grad_net = BpropUserDefinedGradNet(net)
  466. x = Tensor(np.array([2.0], dtype=np.float32))
  467. y = Tensor(np.array([2.0], dtype=np.float32))
  468. try:
  469. grad_net(x, y)
  470. except TypeError as e:
  471. assert "The params of function 'bprop' of Primitive or Cell requires the forward inputs as well as the 'out' " \
  472. "and 'dout'." in str(e)
  473. def test_grad_hook():
  474. def var_hook_function(grad_out):
  475. assert grad_out[0].asnumpy().shape == (32, 120)
  476. class Net(nn.Cell):
  477. def __init__(self):
  478. super(Net, self).__init__()
  479. self.add = P.Add()
  480. self.hook = P.HookBackward(var_hook_function)
  481. def construct(self, x, y):
  482. x = self.hook(x)
  483. out = self.add(x, y)
  484. return out
  485. class GradNetWrtX(nn.Cell):
  486. def __init__(self, net):
  487. super(GradNetWrtX, self).__init__()
  488. self.net = net
  489. self.grad_op = ops.GradOperation()
  490. def construct(self, x, y):
  491. gradient_function = self.grad_op(self.net)
  492. return gradient_function(x, y)
  493. x = Tensor(np.array([2.0], dtype=np.float32))
  494. y = Tensor(np.array([2.0], dtype=np.float32))
  495. try:
  496. GradNetWrtX(Net())(x, y)
  497. except Exception as e:
  498. assert "The Primitive 'HookBackward' is not supported in graph mode, which is only supported in pynative " \
  499. "mode." in str(e)