You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_pynative_hook_grad.py 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_pynative_hook_grad """
  16. import numpy as np
  17. import pytest
  18. import mindspore.nn as nn
  19. import mindspore.ops.operations as P
  20. from mindspore.nn import Cell
  21. from mindspore import context
  22. from mindspore.common.tensor import Tensor
  23. from mindspore.ops.composite import GradOperation
  24. from mindspore.common import ParameterTuple
  25. class MetaFactory:
  26. def __init__(self):
  27. self.device_target = context.get_context('device_target')
  28. self.rank_size = None
  29. self.device_id = None
  30. self.global_rank_id = None
  31. class HookBase(MetaFactory):
  32. def __init__(self):
  33. super().__init__()
  34. MetaFactory.__init__(self)
  35. self.grad_input_list = []
  36. self.grad_output_list = []
  37. def ms_record_hook(self, cell_id, grad_input, grad_output):
  38. for grad in grad_input:
  39. self.grad_input_list.append(grad)
  40. for grad in grad_output:
  41. self.grad_output_list.append(grad)
  42. def ms_change_grad_double_hook(self, cell_id, grad_input, grad_output):
  43. y = Tensor(np.array([2.0]).astype(np.float32))
  44. mul = P.Mul()
  45. grad = grad_output[0]
  46. output = mul(grad, y)
  47. return output
  48. class FinalNet(nn.Cell, HookBase):
  49. def __init__(self):
  50. super().__init__()
  51. HookBase.__init__(self)
  52. self.conv = nn.Conv2d(1, 3, 3)
  53. self.relu = nn.ReLU()
  54. def construct(self, x, flag):
  55. if flag:
  56. x = self.conv(x)
  57. else:
  58. x = self.relu(x)
  59. return self.relu(x)
  60. class _Grad(Cell):
  61. def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
  62. super().__init__()
  63. self.network = network
  64. self.grad = grad
  65. self.sens_param = self.grad.sens_param
  66. self.wrt_params = wrt_params
  67. self.real_inputs_count = real_inputs_count
  68. if self.wrt_params:
  69. self.params = ParameterTuple(self.network.trainable_params())
  70. def construct(self, *inputs):
  71. if self.wrt_params:
  72. if self.real_inputs_count is None or self.sens_param is False:
  73. return self.grad(self.network, self.params)(*inputs)
  74. real_inputs = inputs[:self.real_inputs_count]
  75. sense_param_inputs = inputs[self.real_inputs_count:]
  76. return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
  77. if self.real_inputs_count is None or self.sens_param is False:
  78. return self.grad(self.network)(*inputs)
  79. real_inputs = inputs[:self.real_inputs_count]
  80. sense_param_inputs = inputs[self.real_inputs_count:]
  81. return self.grad(self.network)(*real_inputs, sense_param_inputs)
  82. class GradOfAllInputs(_Grad):
  83. def __init__(self, network, sens_param=True, real_inputs_count=None):
  84. super().__init__(grad=GradOperation(get_all=True, sens_param=sens_param),
  85. network=network, real_inputs_count=real_inputs_count)
  86. class MsMul4(nn.Cell):
  87. def construct(self, input_mul):
  88. out = input_mul * 2
  89. return out
  90. class MsMul(nn.Cell):
  91. def __init__(self):
  92. super().__init__()
  93. self.mul = P.Mul()
  94. def construct(self, x, y):
  95. x = self.mul(x, y)
  96. return x
  97. class MsAdd4(nn.Cell):
  98. def construct(self, input_add):
  99. out = input_add + 4
  100. return out
  101. class MsOneInputNet(nn.Cell, HookBase):
  102. def __init__(self):
  103. super().__init__()
  104. HookBase.__init__(self)
  105. self.add = MsAdd4()
  106. self.mul = MsMul4()
  107. self.relu = nn.ReLU()
  108. def construct(self, x):
  109. x = self.add(x)
  110. x = self.mul(x)
  111. out = self.relu(x)
  112. return out
  113. class MsMultiInputNet(nn.Cell, HookBase):
  114. def __init__(self):
  115. super().__init__()
  116. HookBase.__init__(self)
  117. self.mul1 = MsMul()
  118. self.mul2 = MsMul4()
  119. def construct(self, x, y):
  120. a = self.mul1(x, y)
  121. b = self.mul2(x)
  122. output = self.mul1(a, b)
  123. return output
  124. class MsNetWithParameter(nn.Cell, HookBase):
  125. def __init__(self):
  126. super().__init__()
  127. HookBase.__init__(self)
  128. self.conv1 = nn.Conv2d(2, 4, kernel_size=(1, 1), has_bias=True,
  129. weight_init=Tensor(np.ones([4, 2, 1, 1]).astype(np.float32)),
  130. bias_init=Tensor(np.ones([4]).astype(np.float32)))
  131. self.conv2 = nn.Conv2d(4, 8, kernel_size=(1, 1), has_bias=True,
  132. weight_init=Tensor(np.ones([8, 4, 1, 1]).astype(np.float32)),
  133. bias_init=Tensor(np.ones([8]).astype(np.float32)))
  134. def construct(self, x):
  135. x = self.conv1(x)
  136. output = self.conv2(x)
  137. return output
  138. class MsNetWithCellinCell(nn.Cell, HookBase):
  139. def __init__(self):
  140. super().__init__()
  141. HookBase.__init__(self)
  142. self.net1 = MsOneInputNet()
  143. self.mul = MsMul4()
  144. def construct(self, x):
  145. x = self.net1(x)
  146. output = self.mul(x)
  147. return output
  148. class MsSingleOpNetWithBprop(nn.Cell, HookBase):
  149. def __init__(self):
  150. super().__init__()
  151. HookBase.__init__(self)
  152. self.op = nn.ReLU()
  153. def construct(self, x):
  154. return self.op(x)
  155. def bprop(self, x, out, dout):
  156. y = Tensor(np.array([5.0]).astype(np.float32))
  157. mul = P.Mul()
  158. return mul(x, y)
  159. class MsNetHasBpropInChild(nn.Cell, HookBase):
  160. def __init__(self):
  161. super().__init__()
  162. HookBase.__init__(self)
  163. self.add = MsAdd4()
  164. self.bprop_net = MsSingleOpNetWithBprop()
  165. def construct(self, x):
  166. x = self.add(x)
  167. return self.bprop_net(x)
  168. class MsMultiOpNetWithBprop(nn.Cell, HookBase):
  169. def __init__(self):
  170. super().__init__()
  171. HookBase.__init__(self)
  172. self.mul = MsMul4()
  173. self.relu = nn.ReLU()
  174. def construct(self, x):
  175. x = self.mul(x)
  176. return self.relu(x)
  177. def bprop(self, x, out, dout):
  178. y = Tensor(np.array([5.0]).astype(np.float32))
  179. mul = P.Mul()
  180. return mul(x, y)
  181. def _count_unequal_element(data_expected, data_me, rtol, atol):
  182. assert data_expected.shape == data_me.shape
  183. total_count = len(data_expected.flatten())
  184. error = np.abs(data_expected - data_me)
  185. greater = np.greater(error, atol + np.abs(data_me)*rtol)
  186. loss_count = np.count_nonzero(greater)
  187. assert (loss_count/total_count) < rtol,\
  188. "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".\
  189. format(data_expected[greater], data_me[greater], error[greater])
  190. def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
  191. if np.any(np.isnan(data_expected)):
  192. assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
  193. elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
  194. _count_unequal_element(data_expected, data_me, rtol, atol)
  195. else:
  196. assert True
  197. def pynative_hook_diff_hook():
  198. input_np = np.ones([1, 1, 224, 224]).astype(np.float32)
  199. ms_net = FinalNet()
  200. ms_net.set_grad()
  201. ms_net.conv.register_backward_hook(ms_net.ms_record_hook)
  202. ms_net.relu.register_backward_hook(ms_net.ms_change_grad_double_hook)
  203. input_ms = Tensor(input_np)
  204. out_ms = ms_net(input_ms, Tensor(1))
  205. grad_net = GradOfAllInputs(ms_net)
  206. grad_net.set_train()
  207. grad_net(input_ms, Tensor(1), out_ms)
  208. def pynative_hook_outermost_cell_not_change_grad():
  209. input_np = np.ones([2, 2]).astype(np.float32)
  210. ms_net = MsOneInputNet()
  211. ms_net.set_grad()
  212. ms_net.register_backward_hook(ms_net.ms_record_hook)
  213. input_ms = Tensor(input_np)
  214. out_ms = ms_net(input_ms)
  215. grad_net = GradOfAllInputs(ms_net)
  216. grad_net.set_train()
  217. input_ms_grad = grad_net(input_ms, out_ms)
  218. #input grad
  219. input_torch_grad = np.array([[20, 20], [20, 20]])
  220. allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
  221. #hook record grad
  222. torch_net_grad_output = np.array([[10, 10], [10, 10]])
  223. torch_net_grad_input = np.array([[20, 20], [20, 20]])
  224. allclose_nparray(torch_net_grad_output, ms_net.grad_input_list[0].asnumpy(), 0.001, 0.001)
  225. allclose_nparray(torch_net_grad_input, ms_net.grad_output_list[0].asnumpy(), 0.001, 0.001)
  226. def pynative_hook_all_cell_record_grad():
  227. input_np = np.ones([2, 2]).astype(np.float32)
  228. ms_net = MsOneInputNet()
  229. ms_net.set_grad()
  230. ms_net.mul.register_backward_hook(ms_net.ms_record_hook)
  231. ms_net.add.register_backward_hook(ms_net.ms_record_hook)
  232. ms_net.relu.register_backward_hook(ms_net.ms_record_hook)
  233. input_ms = Tensor(input_np)
  234. out_ms = ms_net(input_ms)
  235. grad_net = GradOfAllInputs(ms_net)
  236. grad_net.set_train()
  237. grad_net(input_ms, out_ms)
  238. torch_net_grad_input0 = np.array([[10, 10], [10, 10]])
  239. torch_net_grad_output0 = np.array([[10, 10], [10, 10]])
  240. torch_net_grad_input1 = np.array([[20, 20], [20, 20]])
  241. torch_net_grad_output1 = np.array([[10, 10], [10, 10]])
  242. allclose_nparray(torch_net_grad_input0, ms_net.grad_output_list[0].asnumpy(), 0.001, 0.001)
  243. allclose_nparray(torch_net_grad_output0, ms_net.grad_input_list[0].asnumpy(), 0.001, 0.001)
  244. allclose_nparray(torch_net_grad_input1, ms_net.grad_output_list[1].asnumpy(), 0.001, 0.001)
  245. allclose_nparray(torch_net_grad_output1, ms_net.grad_input_list[1].asnumpy(), 0.001, 0.001)
  246. torch_net_grad_input3 = np.array([[20, 20], [20, 20]])
  247. torch_net_grad_output2 = np.array([[20, 20], [20, 20]])
  248. allclose_nparray(torch_net_grad_input3, ms_net.grad_output_list[2].asnumpy(), 0.001, 0.001)
  249. allclose_nparray(torch_net_grad_output2, ms_net.grad_input_list[2].asnumpy(), 0.001, 0.001)
  250. def pynative_hook_mul_change_input_grad():
  251. input_np = np.ones([2, 2]).astype(np.float32)
  252. ms_net = MsOneInputNet()
  253. ms_net.set_grad()
  254. ms_net.mul.register_backward_hook(ms_net.ms_change_grad_double_hook)
  255. input_ms = Tensor(input_np)
  256. out_ms = ms_net(input_ms)
  257. grad_net = GradOfAllInputs(ms_net)
  258. grad_net.set_train()
  259. input_ms_grad = grad_net(input_ms, out_ms)
  260. #input grad
  261. input_torch_grad = np.array([[40, 40], [40, 40]])
  262. allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
  263. def pynative_hook_mul2_change_input_grad():
  264. input1_np = np.array([2.0, 3.0, 4.0]).astype(np.float32)
  265. input2_np = np.array([2.0, 3.0, 4.0]).astype(np.float32)
  266. ms_net = MsMultiInputNet()
  267. ms_net.set_grad()
  268. ms_net.mul2.register_backward_hook(ms_net.ms_change_grad_double_hook)
  269. input1_ms = Tensor(input1_np)
  270. input2_ms = Tensor(input2_np)
  271. out_ms = ms_net(input1_ms, input2_ms)
  272. grad_net = GradOfAllInputs(ms_net)
  273. grad_net.set_train()
  274. input_ms_grad = grad_net(input1_ms, input2_ms, out_ms)
  275. #input grad
  276. input1_torch_grad = np.array([384, 2916, 12288])
  277. input2_torch_grad = np.array([128, 972, 4096])
  278. allclose_nparray(input1_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
  279. allclose_nparray(input2_torch_grad, input_ms_grad[1].asnumpy(), 0.001, 0.001)
  280. def pynative_hook_outermost_cell_change_grad():
  281. input_np = np.ones([2, 2]).astype(np.float32)
  282. ms_net = MsNetWithCellinCell()
  283. ms_net.set_grad()
  284. ms_net.register_backward_hook(ms_net.ms_change_grad_double_hook)
  285. input_ms = Tensor(input_np)
  286. out_ms = ms_net(input_ms)
  287. grad_net = GradOfAllInputs(ms_net)
  288. grad_net.set_train()
  289. input_ms_grad = grad_net(input_ms, out_ms)
  290. #input grad
  291. out_torch = np.array([[20, 20], [20, 20]])
  292. input_torch_grad = np.array([[160, 160], [160, 160]])
  293. allclose_nparray(out_torch, out_ms.asnumpy(), 0.001, 0.001)
  294. allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
  295. def pynative_hook_outermost_cell_record_grad():
  296. input_np = np.ones([2, 2]).astype(np.float32)
  297. ms_net = MsSingleOpNetWithBprop()
  298. ms_net.set_grad()
  299. ms_net.bprop_debug = True
  300. ms_net.register_backward_hook(ms_net.ms_record_hook)
  301. input_ms = Tensor(input_np)
  302. out_ms = ms_net(input_ms)
  303. grad_net = GradOfAllInputs(ms_net)
  304. grad_net.set_train()
  305. input_ms_grad = grad_net(input_ms, out_ms)
  306. if ms_net.grad_output_list or ms_net.grad_input_list:
  307. assert False
  308. #input grad
  309. out_torch = np.array([[1, 1], [1, 1]])
  310. input_torch_grad = np.array([[5, 5], [5, 5]])
  311. allclose_nparray(out_torch, out_ms.asnumpy(), 0.001, 0.001)
  312. allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
  313. def pynative_hook_bprop_outermost_cell_record_grad():
  314. input_np = np.ones([2, 2]).astype(np.float32)
  315. ms_net = MsNetHasBpropInChild()
  316. ms_net.set_grad()
  317. ms_net.bprop_net.bprop_debug = True
  318. ms_net.register_backward_hook(ms_net.ms_record_hook)
  319. input_ms = Tensor(input_np)
  320. out_ms = ms_net(input_ms)
  321. grad_net = GradOfAllInputs(ms_net)
  322. grad_net.set_train()
  323. input_ms_grad = grad_net(input_ms, out_ms)
  324. if len(ms_net.grad_output_list) != len(ms_net.grad_input_list) or not ms_net.grad_output_list:
  325. assert False
  326. #input grad
  327. out_torch = np.array([[5, 5], [5, 5]])
  328. input_torch_grad = np.array([[25, 25], [25, 25]])
  329. allclose_nparray(out_torch, out_ms.asnumpy(), 0.001, 0.001)
  330. allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
  331. #hook record grad
  332. torch_net_grad_output = np.array([[5, 5], [5, 5]])
  333. torch_net_grad_input = np.array([[25, 25], [25, 25]])
  334. allclose_nparray(torch_net_grad_output, ms_net.grad_input_list[0].asnumpy(), 0.001, 0.001)
  335. allclose_nparray(torch_net_grad_input, ms_net.grad_output_list[0].asnumpy(), 0.001, 0.001)
  336. def pynative_hook_child_cell_record_grad():
  337. input_np = np.ones([2, 2]).astype(np.float32)
  338. ms_net = MsMultiOpNetWithBprop()
  339. ms_net.set_grad()
  340. ms_net.bprop_debug = True
  341. ms_net.relu.register_backward_hook(ms_net.ms_record_hook)
  342. ms_net.mul.register_backward_hook(ms_net.ms_record_hook)
  343. input_ms = Tensor(input_np)
  344. out_ms = ms_net(input_ms)
  345. grad_net = GradOfAllInputs(ms_net)
  346. grad_net.set_train()
  347. grad_net(input_ms, out_ms)
  348. if ms_net.grad_output_list or ms_net.grad_input_list:
  349. assert False
  350. @pytest.mark.level0
  351. @pytest.mark.platform_arm_ascend_training
  352. @pytest.mark.platform_x86_ascend_training
  353. @pytest.mark.env_onecard
  354. def test_pynative_hook_diff_hook_ascend():
  355. context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
  356. pynative_hook_diff_hook()
  357. @pytest.mark.level0
  358. @pytest.mark.platform_x86_gpu_training
  359. @pytest.mark.env_onecard
  360. def test_pynative_hook_diff_hook_gpu():
  361. context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
  362. pynative_hook_diff_hook()
  363. @pytest.mark.level0
  364. @pytest.mark.platform_arm_ascend_training
  365. @pytest.mark.platform_x86_ascend_training
  366. @pytest.mark.env_onecard
  367. def test_pynative_hook_outermost_cell_not_change_grad_ascend():
  368. context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
  369. pynative_hook_outermost_cell_not_change_grad()
  370. @pytest.mark.level0
  371. @pytest.mark.platform_x86_gpu_training
  372. @pytest.mark.env_onecard
  373. def test_pynative_hook_outermost_cell_not_change_grad_gpu():
  374. context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
  375. pynative_hook_outermost_cell_not_change_grad()
  376. @pytest.mark.level0
  377. @pytest.mark.platform_arm_ascend_training
  378. @pytest.mark.platform_x86_ascend_training
  379. @pytest.mark.env_onecard
  380. def test_pynative_hook_all_cell_record_grad_ascend():
  381. context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
  382. pynative_hook_all_cell_record_grad()
  383. @pytest.mark.level0
  384. @pytest.mark.platform_x86_gpu_training
  385. @pytest.mark.env_onecard
  386. def test_pynative_hook_all_cell_record_grad_gpu():
  387. context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
  388. pynative_hook_all_cell_record_grad()
  389. @pytest.mark.level0
  390. @pytest.mark.platform_arm_ascend_training
  391. @pytest.mark.platform_x86_ascend_training
  392. @pytest.mark.env_onecard
  393. def test_pynative_hook_mul_change_input_grad_ascend():
  394. context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
  395. pynative_hook_mul_change_input_grad()
  396. @pytest.mark.level0
  397. @pytest.mark.platform_x86_gpu_training
  398. @pytest.mark.env_onecard
  399. def test_pynative_hook_mul_change_input_grad_gpu():
  400. context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
  401. pynative_hook_mul_change_input_grad()
  402. @pytest.mark.level0
  403. @pytest.mark.platform_arm_ascend_training
  404. @pytest.mark.platform_x86_ascend_training
  405. @pytest.mark.env_onecard
  406. def test_pynative_hook_mul2_change_input_grad_ascend():
  407. context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
  408. pynative_hook_mul2_change_input_grad()
  409. @pytest.mark.level0
  410. @pytest.mark.platform_x86_gpu_training
  411. @pytest.mark.env_onecard
  412. def test_pynative_hook_mul2_change_input_grad_gpu():
  413. context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
  414. pynative_hook_mul2_change_input_grad()
  415. @pytest.mark.level0
  416. @pytest.mark.platform_arm_ascend_training
  417. @pytest.mark.platform_x86_ascend_training
  418. @pytest.mark.env_onecard
  419. def test_pynative_hook_outermost_cell_change_grad_ascend():
  420. context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
  421. pynative_hook_outermost_cell_change_grad()
  422. @pytest.mark.level0
  423. @pytest.mark.platform_x86_gpu_training
  424. @pytest.mark.env_onecard
  425. def test_pynative_hook_outermost_cell_change_grad_gpu():
  426. context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
  427. pynative_hook_outermost_cell_change_grad()
  428. @pytest.mark.level0
  429. @pytest.mark.platform_arm_ascend_training
  430. @pytest.mark.platform_x86_ascend_training
  431. @pytest.mark.env_onecard
  432. def test_pynative_hook_outermost_cell_record_grad_ascend():
  433. context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
  434. pynative_hook_outermost_cell_record_grad()
  435. @pytest.mark.level0
  436. @pytest.mark.platform_x86_gpu_training
  437. @pytest.mark.env_onecard
  438. def test_pynative_hook_outermost_cell_record_grad_gpu():
  439. context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
  440. pynative_hook_outermost_cell_record_grad()
  441. @pytest.mark.level0
  442. @pytest.mark.platform_arm_ascend_training
  443. @pytest.mark.platform_x86_ascend_training
  444. @pytest.mark.env_onecard
  445. def test_pynative_hook_bprop_outermost_cell_record_grad_ascend():
  446. context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
  447. pynative_hook_bprop_outermost_cell_record_grad()
  448. @pytest.mark.level0
  449. @pytest.mark.platform_x86_gpu_training
  450. @pytest.mark.env_onecard
  451. def test_pynative_hook_bprop_outermost_cell_record_grad_gpu():
  452. context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
  453. pynative_hook_bprop_outermost_cell_record_grad()
  454. @pytest.mark.level0
  455. @pytest.mark.platform_arm_ascend_training
  456. @pytest.mark.platform_x86_ascend_training
  457. @pytest.mark.env_onecard
  458. def test_pynative_hook_child_cell_record_grad_ascend():
  459. context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
  460. pynative_hook_child_cell_record_grad()
  461. @pytest.mark.level0
  462. @pytest.mark.platform_x86_gpu_training
  463. @pytest.mark.env_onecard
  464. def test_pynative_hook_child_cell_record_grad_gpu():
  465. context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
  466. pynative_hook_child_cell_record_grad()