You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_auto_grad.py 22 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import mindspore.nn as nn
  17. import mindspore.ops as ops
  18. from mindspore import context
  19. from mindspore import Tensor
  20. from mindspore.ops import operations as P
  21. from mindspore.ops import composite as C
  22. from mindspore.common.parameter import Parameter, ParameterTuple
# Shared gradient operators for the tests below:
#   grad_all     - returns the gradients w.r.t. all network inputs.
#   grad_by_list - returns the gradients w.r.t. a ParameterTuple of weights.
grad_all = C.GradOperation(get_all=True)
grad_by_list = C.GradOperation(get_by_list=True)
  25. class CropAndResizeNet(nn.Cell):
  26. def __init__(self, crop_size):
  27. super(CropAndResizeNet, self).__init__()
  28. self.crop_and_resize = P.CropAndResize()
  29. self.crop_size = crop_size
  30. def construct(self, x, boxes, box_indices):
  31. return self.crop_and_resize(x, boxes, box_indices, self.crop_size)
  32. def bprop(self, x, boxes, box_indices, out, dout):
  33. return x, boxes, box_indices
  34. class TestUserDefinedBpropNet(nn.Cell):
  35. def __init__(self, in_channel, out_channel):
  36. super(TestUserDefinedBpropNet, self).__init__()
  37. self.relu = nn.ReLU()
  38. self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=2, stride=1, has_bias=False,
  39. weight_init='ones', pad_mode='same')
  40. self.crop = CropAndResizeNet((10, 10))
  41. self.boxes = Tensor(np.ones((128, 4)).astype(np.float32))
  42. self.box_indices = Tensor(np.ones((128,)).astype(np.int32))
  43. def construct(self, x):
  44. x = self.relu(x)
  45. x = self.conv(x)
  46. x = self.crop(x, self.boxes, self.box_indices)
  47. return x
  48. class TestUserDefinedBpropGradNet(nn.Cell):
  49. def __init__(self, net):
  50. super(TestUserDefinedBpropGradNet, self).__init__()
  51. self.net = net
  52. def construct(self, x):
  53. return grad_all(self.net)(x)
  54. def test_user_defined_bprop():
  55. context.set_context(mode=context.GRAPH_MODE)
  56. net = TestUserDefinedBpropNet(3, 10)
  57. grad_net = TestUserDefinedBpropGradNet(net)
  58. x = Tensor(np.ones((128, 3, 12, 12)).astype(np.float32))
  59. grad_net(x)
  60. class TwoInputBPropOperator(nn.Cell):
  61. def __init__(self):
  62. super().__init__()
  63. self.op = P.Mul()
  64. self.add = P.Add()
  65. def construct(self, x, y):
  66. return self.op(x, y)
  67. def bprop(self, x, y, out, dout):
  68. return self.add(5, x), self.add(y, 9)
  69. class BPropOperatatorNet(nn.Cell):
  70. def __init__(self, mul_size):
  71. super().__init__()
  72. mul_np = np.full(mul_size, 0.1, dtype=np.float32)
  73. floordiv_np = np.full(mul_size, 0.1, dtype=np.float32)
  74. self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight")
  75. self.floordiv_weight = Parameter(Tensor(floordiv_np), name="floordiv_weight")
  76. self.mul = TwoInputBPropOperator()
  77. self.floor_div = P.FloorDiv()
  78. self.bn = nn.BatchNorm1d(num_features=96)
  79. def construct(self, inputs):
  80. x = self.mul(inputs, self.mul_weight)
  81. x = self.floor_div(x, self.floordiv_weight)
  82. x = self.bn(x)
  83. return x
  84. def test_user_defined_bprop_with_u():
  85. net = BPropOperatatorNet(mul_size=(128, 96))
  86. grad_net = TestUserDefinedBpropGradNet(net)
  87. x = Tensor(np.random.randn(128, 96).astype(np.float32))
  88. grad_net(x)
  89. class SinNet(nn.Cell):
  90. def __init__(self):
  91. super(SinNet, self).__init__()
  92. self.sin = ops.Sin()
  93. def construct(self, x):
  94. out = self.sin(x)
  95. return out
  96. class SinGrad(nn.Cell):
  97. def __init__(self, network):
  98. super(SinGrad, self).__init__()
  99. self.grad = ops.GradOperation()
  100. self.network = network
  101. def construct(self, x):
  102. gout = self.grad(self.network)(x)
  103. return gout
  104. class SinGradSec(nn.Cell):
  105. def __init__(self, network):
  106. super(SinGradSec, self).__init__()
  107. self.grad = ops.GradOperation()
  108. self.network = network
  109. def construct(self, x):
  110. gout = self.grad(self.network)(x)
  111. return gout
  112. def test_second_grad_with_j_primitive():
  113. context.set_context(mode=context.GRAPH_MODE)
  114. net = SinNet()
  115. first_grad = SinGrad(net)
  116. second_grad = SinGradSec(first_grad)
  117. x = Tensor(np.array([1.0], dtype=np.float32))
  118. second_grad(x)
  119. # A CNode being used as FV is MapMorphism after MapMorphism of call-site CNode;
  120. def test_ad_fv_cnode_order():
  121. context.set_context(mode=context.GRAPH_MODE)
  122. class Net(nn.Cell):
  123. def __init__(self):
  124. super(Net, self).__init__()
  125. # cnode xay is not being MapMorphism when cnode second_level() is being MapMorphism and
  126. # BackPropagateFv as MapMorphism is started from output node and from left to right order.
  127. def construct(self, x, y):
  128. def first_level():
  129. xay = x + y
  130. def second_level():
  131. return xay
  132. return second_level() + xay
  133. return first_level()
  134. input_x = Tensor(np.array([1.0], dtype=np.float32))
  135. input_y = Tensor(np.array([2.0], dtype=np.float32))
  136. net = Net()
  137. net.add_flags_recursive(defer_inline=True)
  138. grad_net = grad_all(net)
  139. grad_net(input_x, input_y)
  140. # True and False branch of switch have different number of parameters.
  141. def test_if_branch_with_different_params():
  142. context.set_context(mode=context.GRAPH_MODE)
  143. class Net(nn.Cell):
  144. def __init__(self):
  145. super(Net, self).__init__()
  146. self.weight1 = Parameter(Tensor(np.array([1.0], dtype=np.float32)), name="weight1")
  147. self.weight2 = Parameter(Tensor(np.array([2.0], dtype=np.float32)), name="weight2")
  148. def construct(self, idx, end, x):
  149. out = x
  150. if idx < end:
  151. out = out + self.weight1 * self.weight2
  152. else:
  153. out = out + self.weight1
  154. return out
  155. class GradNet(nn.Cell):
  156. def __init__(self, net):
  157. super(GradNet, self).__init__()
  158. self.net = net
  159. self.weights = ParameterTuple(net.trainable_params())
  160. def construct(self, idx, end, x):
  161. return grad_by_list(self.net, self.weights)(idx, end, x)
  162. idx = Tensor(np.array((0), dtype=np.int32))
  163. end = Tensor(np.array((3), dtype=np.int32))
  164. x = Tensor(np.array([2.0], dtype=np.float32))
  165. net = Net()
  166. grad_net = GradNet(net)
  167. grad_net(idx, end, x)
  168. # Only lift fv in scope of lift_top_func_graph other than all func_graphs inside manager.
  169. # Otherwise, "Illegal AnfNode for evaluating" may be reported
  170. # because weight1 in Net may use old_parameter other than replicated one.
  171. def test_limit_lift_fv_scope():
  172. context.set_context(mode=context.GRAPH_MODE)
  173. class Net(nn.Cell):
  174. def __init__(self):
  175. super(Net, self).__init__()
  176. self.weight1 = Parameter(Tensor(np.array([1.0], dtype=np.float32)), name="weight1")
  177. def construct(self, x, y):
  178. def inner_add(a, b):
  179. return a + b
  180. out = inner_add(x, y) + self.weight1
  181. return out
  182. class GradNet(nn.Cell):
  183. def __init__(self, net):
  184. super(GradNet, self).__init__()
  185. self.net = net
  186. self.weights = ParameterTuple(net.trainable_params())
  187. def construct(self, x, y):
  188. def inner_grad_add(a, b):
  189. return a + b
  190. d_weight = grad_by_list(self.net, self.weights)(x, y)[0]
  191. d_out = inner_grad_add(d_weight, y)
  192. return d_out
  193. x = Tensor(np.array([2.0], dtype=np.float32))
  194. y = Tensor(np.array([2.0], dtype=np.float32))
  195. net = Net()
  196. net.add_flags_recursive(defer_inline=True)
  197. grad_net = GradNet(net)
  198. grad_net.add_flags_recursive(defer_inline=True)
  199. grad_net(x, y)
  200. def test_same_primal_used_by_multi_j():
  201. class Net(nn.Cell):
  202. def __init__(self):
  203. super(Net, self).__init__()
  204. def construct(self, x):
  205. return x
  206. class GradNet(nn.Cell):
  207. def __init__(self, net):
  208. super(GradNet, self).__init__()
  209. self.net = net
  210. self.grad = ops.GradOperation()
  211. def construct(self, x):
  212. out = self.net(x)
  213. gout = self.grad(self.net)(x)
  214. gout1 = self.grad(self.net)(x)
  215. return out, gout, gout1
  216. x = Tensor(np.array([1.0], dtype=np.float32))
  217. net = Net()
  218. grad = GradNet(net)
  219. grad(x)
  220. def test_same_primal_used_by_multi_j_with_monad1():
  221. context.set_context(mode=context.GRAPH_MODE)
  222. class AdamNet(nn.Cell):
  223. def __init__(self, var, m, v):
  224. super(AdamNet, self).__init__()
  225. self.apply_adam = P.Adam()
  226. self.var = Parameter(var, name="var")
  227. self.m = Parameter(m, name="m")
  228. self.v = Parameter(v, name="v")
  229. def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
  230. self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
  231. return self.var
  232. class AdamGradNet(nn.Cell):
  233. def __init__(self, network):
  234. super(AdamGradNet, self).__init__()
  235. self.grad_fn = ops.GradOperation(sens_param=True)
  236. self.sens = [Tensor(np.ones([3, 3, 3]).astype(np.float32)), Tensor(np.ones([3, 3, 3]).astype(np.float32))]
  237. self.network = network
  238. def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
  239. out = self.network(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
  240. gout1 = self.grad_fn(self.network)(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[0])
  241. gout2 = self.grad_fn(self.network)(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[1])
  242. return out, gout1, gout2
  243. var = Tensor(np.ones([3, 3, 3]).astype(np.float32))
  244. m = Tensor(np.ones([3, 3, 3]).astype(np.float32))
  245. v = Tensor(np.ones([3, 3, 3]).astype(np.float32))
  246. beta1_power = Tensor(np.array([0.9], dtype=np.float32))
  247. beta2_power = Tensor(np.array([0.999], dtype=np.float32))
  248. lr = Tensor(np.array([0.001], dtype=np.float32))
  249. beta1 = Tensor(np.array([0.9], dtype=np.float32))
  250. beta2 = Tensor(np.array([0.999], dtype=np.float32))
  251. epsilon = Tensor(np.array([1e-8], dtype=np.float32))
  252. grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
  253. net = AdamNet(var, m, v)
  254. grad_net = AdamGradNet(net)
  255. grad_net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
  256. def test_same_primal_used_by_multi_j_with_monad2():
  257. context.set_context(mode=context.GRAPH_MODE)
  258. class AdamNet(nn.Cell):
  259. def __init__(self, var, m, v):
  260. super(AdamNet, self).__init__()
  261. self.apply_adam = P.Adam()
  262. self.var = Parameter(var, name="var")
  263. self.m = Parameter(m, name="m")
  264. self.v = Parameter(v, name="v")
  265. def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
  266. self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
  267. return self.var
  268. class AdamGradNet(nn.Cell):
  269. def __init__(self, network):
  270. super(AdamGradNet, self).__init__()
  271. self.grad = ops.GradOperation(sens_param=True)
  272. self.sens = [Tensor(np.ones([3, 3, 3]).astype(np.float32)), Tensor(np.ones([3, 3, 3]).astype(np.float32))]
  273. self.network = network
  274. def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
  275. out = self.network(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
  276. grad_fn = self.grad(self.network)
  277. gout1 = grad_fn(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[0])
  278. gout2 = grad_fn(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, self.sens[1])
  279. return out, gout1, gout2
  280. var = Tensor(np.ones([3, 3, 3]).astype(np.float32))
  281. m = Tensor(np.ones([3, 3, 3]).astype(np.float32))
  282. v = Tensor(np.ones([3, 3, 3]).astype(np.float32))
  283. beta1_power = Tensor(np.array([0.9], dtype=np.float32))
  284. beta2_power = Tensor(np.array([0.999], dtype=np.float32))
  285. lr = Tensor(np.array([0.001], dtype=np.float32))
  286. beta1 = Tensor(np.array([0.9], dtype=np.float32))
  287. beta2 = Tensor(np.array([0.999], dtype=np.float32))
  288. epsilon = Tensor(np.array([1e-8], dtype=np.float32))
  289. grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32))
  290. net = AdamNet(var, m, v)
  291. grad_net = AdamGradNet(net)
  292. grad_net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
  293. def test_grad_args_type_error1():
  294. class Net(nn.Cell):
  295. def __init__(self):
  296. super(Net, self).__init__()
  297. self.matmul = P.MatMul()
  298. def construct(self, x, y):
  299. out = self.matmul(x, y)
  300. return out
  301. class GradNetWrtX(nn.Cell):
  302. def __init__(self, net):
  303. super(GradNetWrtX, self).__init__()
  304. self.net = net
  305. self.grad_op = ops.GradOperation(get_all=2)
  306. def construct(self, x, y):
  307. gradient_function = self.grad_op(self.net)
  308. return gradient_function(x, y)
  309. x = Tensor(np.array([2.0], dtype=np.float32))
  310. y = Tensor(np.array([2.0], dtype=np.float32))
  311. try:
  312. GradNetWrtX(Net())(x, y)
  313. except TypeError as e:
  314. assert "For 'GradOperation', the 'get_all' should be bool, but got" in str(e)
  315. def test_grad_args_type_error2():
  316. class Net(nn.Cell):
  317. def __init__(self):
  318. super(Net, self).__init__()
  319. self.matmul = P.MatMul()
  320. def construct(self, x, y):
  321. out = self.matmul(x, y)
  322. return out
  323. class GradNetWrtX(nn.Cell):
  324. def __init__(self, net):
  325. super(GradNetWrtX, self).__init__()
  326. self.net = net
  327. self.grad_op = ops.GradOperation(get_by_list=2)
  328. def construct(self, x, y):
  329. gradient_function = self.grad_op(self.net)
  330. return gradient_function(x, y)
  331. x = Tensor(np.array([2.0], dtype=np.float32))
  332. y = Tensor(np.array([2.0], dtype=np.float32))
  333. try:
  334. GradNetWrtX(Net())(x, y)
  335. except TypeError as e:
  336. assert "For 'GradOperation', the 'get_by_list' should be bool, but got" in str(e)
  337. def test_grad_args_type_error3():
  338. class Net(nn.Cell):
  339. def __init__(self):
  340. super(Net, self).__init__()
  341. self.matmul = P.MatMul()
  342. def construct(self, x, y):
  343. out = self.matmul(x, y)
  344. return out
  345. class GradNetWrtX(nn.Cell):
  346. def __init__(self, net):
  347. super(GradNetWrtX, self).__init__()
  348. self.net = net
  349. self.grad_op = ops.GradOperation(sens_param=2)
  350. def construct(self, x, y):
  351. gradient_function = self.grad_op(self.net)
  352. return gradient_function(x, y)
  353. x = Tensor(np.array([2.0], dtype=np.float32))
  354. y = Tensor(np.array([2.0], dtype=np.float32))
  355. try:
  356. GradNetWrtX(Net())(x, y)
  357. except TypeError as e:
  358. assert "For 'GradOperation', the 'sens_param' should be bool, but got" in str(e)
  359. def test_grad_net_is_none():
  360. class Net(nn.Cell):
  361. def __init__(self):
  362. super(Net, self).__init__()
  363. self.add = P.Add()
  364. def construct(self, x, y):
  365. out = self.add(x, y)
  366. return out
  367. class GradNetWrtX(nn.Cell):
  368. def __init__(self, net):
  369. super(GradNetWrtX, self).__init__()
  370. self.net = P.Add()
  371. self.grad_op = ops.GradOperation()
  372. def construct(self, x, y):
  373. gradient_function = self.grad_op(None)
  374. return gradient_function(x, y)
  375. x = Tensor(np.array([2.0], dtype=np.float32))
  376. y = Tensor(np.array([2.0], dtype=np.float32))
  377. try:
  378. GradNetWrtX(Net())(x, y)
  379. except Exception as e:
  380. assert "'GradOperation' arg0 must be a 'Function' or 'Cell', but got" in str(e)
  381. def test_grad_missing_net():
  382. class Net(nn.Cell):
  383. def __init__(self):
  384. super(Net, self).__init__()
  385. self.add = P.Add()
  386. def construct(self, x, y):
  387. out = self.add(x, y)
  388. return out
  389. class GradNetWrtX(nn.Cell):
  390. def __init__(self, net):
  391. super(GradNetWrtX, self).__init__()
  392. self.net = net
  393. self.grad_op = ops.GradOperation()
  394. def construct(self, x, y):
  395. gradient_function = self.grad_op()
  396. return gradient_function(x, y)
  397. x = Tensor(np.array([2.0], dtype=np.float32))
  398. y = Tensor(np.array([2.0], dtype=np.float32))
  399. try:
  400. GradNetWrtX(Net())(x, y)
  401. except Exception as e:
  402. assert "'GradOperation' requires a forward network or function as an input, while the input is empty." in str(e)
  403. def test_user_defined_bprop_inputs_size_error():
  404. class BpropUserDefinedNet(nn.Cell):
  405. def __init__(self):
  406. super(BpropUserDefinedNet, self).__init__()
  407. self.zeros_like = P.ZerosLike()
  408. def construct(self, x, y):
  409. return x + y
  410. def bprop(self, out):
  411. return self.zeros_like(out), self.zeros_like(out)
  412. class BpropUserDefinedGradNet(nn.Cell):
  413. def __init__(self, net):
  414. super(BpropUserDefinedGradNet, self).__init__()
  415. self.net = net
  416. def construct(self, x, y):
  417. return grad_all(self.net)(x, y)
  418. net = BpropUserDefinedNet()
  419. grad_net = BpropUserDefinedGradNet(net)
  420. x = Tensor(np.array([2.0], dtype=np.float32))
  421. y = Tensor(np.array([2.0], dtype=np.float32))
  422. try:
  423. grad_net(x, y)
  424. except Exception as e:
  425. assert "The function 'bprop' of Primitive or Cell requires at least 2 params 'out' and 'dout', but got only"\
  426. in str(e)
  427. def test_user_defined_bprop_net_has_parameter():
  428. class BpropUserDefinedNet(nn.Cell):
  429. def __init__(self):
  430. super(BpropUserDefinedNet, self).__init__()
  431. self.zeros_like = P.ZerosLike()
  432. self.x = Parameter(Tensor(np.array([2.0], dtype=np.float32)), name="x")
  433. def construct(self, y):
  434. return self.x + y
  435. def bprop(self, y, out, dout):
  436. return (self.zeros_like(out),)
  437. class BpropUserDefinedGradNet(nn.Cell):
  438. def __init__(self, net):
  439. super(BpropUserDefinedGradNet, self).__init__()
  440. self.net = net
  441. def construct(self, y):
  442. return grad_all(self.net)(y)
  443. net = BpropUserDefinedNet()
  444. grad_net = BpropUserDefinedGradNet(net)
  445. y = Tensor(np.array([2.0], dtype=np.float32))
  446. try:
  447. grad_net(y)
  448. except Exception as e:
  449. assert "The Cell with user defined 'bprop' function in scope" in str(e)
  450. assert "does not support Parameter data type." in str(e)
  451. def test_user_defined_bprop_inputs_size_error1():
  452. class BpropUserDefinedNet(nn.Cell):
  453. def __init__(self):
  454. super(BpropUserDefinedNet, self).__init__()
  455. self.zeros_like = P.ZerosLike()
  456. def construct(self, x, y):
  457. return x + y
  458. def bprop(self, x, y, out):
  459. return self.zeros_like(out), self.zeros_like(out)
  460. class BpropUserDefinedGradNet(nn.Cell):
  461. def __init__(self, net):
  462. super(BpropUserDefinedGradNet, self).__init__()
  463. self.net = net
  464. def construct(self, x, y):
  465. return grad_all(self.net)(x, y)
  466. net = BpropUserDefinedNet()
  467. grad_net = BpropUserDefinedGradNet(net)
  468. x = Tensor(np.array([2.0], dtype=np.float32))
  469. y = Tensor(np.array([2.0], dtype=np.float32))
  470. try:
  471. grad_net(x, y)
  472. except TypeError as e:
  473. assert "The params of function 'bprop' of Primitive or Cell requires the forward inputs as well as the 'out' " \
  474. "and 'dout'." in str(e)
  475. def test_grad_hook():
  476. def var_hook_function(grad_out):
  477. assert grad_out[0].asnumpy().shape == (32, 120)
  478. class Net(nn.Cell):
  479. def __init__(self):
  480. super(Net, self).__init__()
  481. self.add = P.Add()
  482. self.hook = P.HookBackward(var_hook_function)
  483. def construct(self, x, y):
  484. x = self.hook(x)
  485. out = self.add(x, y)
  486. return out
  487. class GradNetWrtX(nn.Cell):
  488. def __init__(self, net):
  489. super(GradNetWrtX, self).__init__()
  490. self.net = net
  491. self.grad_op = ops.GradOperation()
  492. def construct(self, x, y):
  493. gradient_function = self.grad_op(self.net)
  494. return gradient_function(x, y)
  495. x = Tensor(np.array([2.0], dtype=np.float32))
  496. y = Tensor(np.array([2.0], dtype=np.float32))
  497. try:
  498. GradNetWrtX(Net())(x, y)
  499. except Exception as e:
  500. assert "The Primitive 'HookBackward' is not supported in graph mode, which is only supported in pynative " \
  501. "mode." in str(e)