You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_pynative_hook_grad.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_pynative_hook_grad """
  16. import numpy as np
  17. import pytest
  18. import mindspore.nn as nn
  19. import mindspore.ops.operations as P
  20. from mindspore.nn import Cell
  21. from mindspore import context
  22. from mindspore.common.tensor import Tensor
  23. from mindspore.ops.composite import GradOperation
  24. from mindspore.common import ParameterTuple
  25. def setup_module():
  26. context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
  27. class MetaFactory:
  28. def __init__(self):
  29. self.device_target = context.get_context('device_target')
  30. self.rank_size = None
  31. self.device_id = None
  32. self.global_rank_id = None
  33. class HookBase(MetaFactory):
  34. def __init__(self):
  35. super().__init__()
  36. MetaFactory.__init__(self)
  37. self.grad_input_list = []
  38. self.grad_output_list = []
  39. def ms_record_hook(self, cell_id, grad_input, grad_output):
  40. for grad in grad_input:
  41. self.grad_input_list.append(grad)
  42. for grad in grad_output:
  43. self.grad_output_list.append(grad)
  44. def ms_change_grad_double_hook(self, cell_id, grad_input, grad_output):
  45. y = Tensor(np.array([2.0]).astype(np.float32))
  46. mul = P.Mul()
  47. grad = grad_output[0]
  48. output = mul(grad, y)
  49. return output
  50. class FinalNet(nn.Cell, HookBase):
  51. def __init__(self):
  52. super().__init__()
  53. HookBase.__init__(self)
  54. self.conv = nn.Conv2d(1, 3, 3)
  55. self.relu = nn.ReLU()
  56. def construct(self, x, flag):
  57. if flag:
  58. x = self.conv(x)
  59. else:
  60. x = self.relu(x)
  61. return self.relu(x)
  62. class _Grad(Cell):
  63. def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
  64. super().__init__()
  65. self.network = network
  66. self.grad = grad
  67. self.sens_param = self.grad.sens_param
  68. self.wrt_params = wrt_params
  69. self.real_inputs_count = real_inputs_count
  70. if self.wrt_params:
  71. self.params = ParameterTuple(self.network.trainable_params())
  72. def construct(self, *inputs):
  73. if self.wrt_params:
  74. if self.real_inputs_count is None or self.sens_param is False:
  75. return self.grad(self.network, self.params)(*inputs)
  76. real_inputs = inputs[:self.real_inputs_count]
  77. sense_param_inputs = inputs[self.real_inputs_count:]
  78. return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
  79. if self.real_inputs_count is None or self.sens_param is False:
  80. return self.grad(self.network)(*inputs)
  81. real_inputs = inputs[:self.real_inputs_count]
  82. sense_param_inputs = inputs[self.real_inputs_count:]
  83. return self.grad(self.network)(*real_inputs, sense_param_inputs)
  84. class GradOfAllInputs(_Grad):
  85. def __init__(self, network, sens_param=True, real_inputs_count=None):
  86. super().__init__(grad=GradOperation(get_all=True, sens_param=sens_param),
  87. network=network, real_inputs_count=real_inputs_count)
  88. class MsMul4(nn.Cell):
  89. def construct(self, input_mul):
  90. out = input_mul * 2
  91. return out
  92. class MsMul(nn.Cell):
  93. def __init__(self):
  94. super().__init__()
  95. self.mul = P.Mul()
  96. def construct(self, x, y):
  97. x = self.mul(x, y)
  98. return x
  99. class MsAdd4(nn.Cell):
  100. def construct(self, input_add):
  101. out = input_add + 4
  102. return out
  103. class MsOneInputNet(nn.Cell, HookBase):
  104. def __init__(self):
  105. super().__init__()
  106. HookBase.__init__(self)
  107. self.add = MsAdd4()
  108. self.mul = MsMul4()
  109. self.relu = nn.ReLU()
  110. def construct(self, x):
  111. x = self.add(x)
  112. x = self.mul(x)
  113. out = self.relu(x)
  114. return out
  115. class MsMultiInputNet(nn.Cell, HookBase):
  116. def __init__(self):
  117. super().__init__()
  118. HookBase.__init__(self)
  119. self.mul1 = MsMul()
  120. self.mul2 = MsMul4()
  121. def construct(self, x, y):
  122. a = self.mul1(x, y)
  123. b = self.mul2(x)
  124. output = self.mul1(a, b)
  125. return output
  126. class MsNetWithParameter(nn.Cell, HookBase):
  127. def __init__(self):
  128. super().__init__()
  129. HookBase.__init__(self)
  130. self.conv1 = nn.Conv2d(2, 4, kernel_size=(1, 1), has_bias=True,
  131. weight_init=Tensor(np.ones([4, 2, 1, 1]).astype(np.float32)),
  132. bias_init=Tensor(np.ones([4]).astype(np.float32)))
  133. self.conv2 = nn.Conv2d(4, 8, kernel_size=(1, 1), has_bias=True,
  134. weight_init=Tensor(np.ones([8, 4, 1, 1]).astype(np.float32)),
  135. bias_init=Tensor(np.ones([8]).astype(np.float32)))
  136. def construct(self, x):
  137. x = self.conv1(x)
  138. output = self.conv2(x)
  139. return output
  140. class MsNetWithCellinCell(nn.Cell, HookBase):
  141. def __init__(self):
  142. super().__init__()
  143. HookBase.__init__(self)
  144. self.net1 = MsOneInputNet()
  145. self.mul = MsMul4()
  146. def construct(self, x):
  147. x = self.net1(x)
  148. output = self.mul(x)
  149. return output
  150. class MsSingleOpNetWithBprop(nn.Cell, HookBase):
  151. def __init__(self):
  152. super().__init__()
  153. HookBase.__init__(self)
  154. self.op = nn.ReLU()
  155. def construct(self, x):
  156. return self.op(x)
  157. def bprop(self, x, out, dout):
  158. y = Tensor(np.array([5.0]).astype(np.float32))
  159. mul = P.Mul()
  160. return mul(x, y)
  161. class MsNetHasBpropInChild(nn.Cell, HookBase):
  162. def __init__(self):
  163. super().__init__()
  164. HookBase.__init__(self)
  165. self.add = MsAdd4()
  166. self.bprop_net = MsSingleOpNetWithBprop()
  167. def construct(self, x):
  168. x = self.add(x)
  169. return self.bprop_net(x)
  170. class MsMultiOpNetWithBprop(nn.Cell, HookBase):
  171. def __init__(self):
  172. super().__init__()
  173. HookBase.__init__(self)
  174. self.mul = MsMul4()
  175. self.relu = nn.ReLU()
  176. def construct(self, x):
  177. x = self.mul(x)
  178. return self.relu(x)
  179. def bprop(self, x, out, dout):
  180. y = Tensor(np.array([5.0]).astype(np.float32))
  181. mul = P.Mul()
  182. return mul(x, y)
  183. def _count_unequal_element(data_expected, data_me, rtol, atol):
  184. assert data_expected.shape == data_me.shape
  185. total_count = len(data_expected.flatten())
  186. error = np.abs(data_expected - data_me)
  187. greater = np.greater(error, atol + np.abs(data_me)*rtol)
  188. loss_count = np.count_nonzero(greater)
  189. assert (loss_count/total_count) < rtol,\
  190. "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".\
  191. format(data_expected[greater], data_me[greater], error[greater])
  192. def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
  193. if np.any(np.isnan(data_expected)):
  194. assert np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan)
  195. elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan):
  196. _count_unequal_element(data_expected, data_me, rtol, atol)
  197. else:
  198. assert True
  199. @pytest.mark.level0
  200. @pytest.mark.platform_arm_ascend_training
  201. @pytest.mark.platform_x86_ascend_training
  202. @pytest.mark.env_onecard
  203. def test_pynative_hook_diff_hook():
  204. input_np = np.ones([1, 1, 224, 224]).astype(np.float32)
  205. ms_net = FinalNet()
  206. ms_net.set_grad()
  207. ms_net.conv.register_backward_hook(ms_net.ms_record_hook)
  208. ms_net.relu.register_backward_hook(ms_net.ms_change_grad_double_hook)
  209. input_ms = Tensor(input_np)
  210. out_ms = ms_net(input_ms, Tensor(1))
  211. grad_net = GradOfAllInputs(ms_net)
  212. grad_net.set_train()
  213. grad_net(input_ms, Tensor(1), out_ms)
  214. @pytest.mark.level0
  215. @pytest.mark.platform_arm_ascend_training
  216. @pytest.mark.platform_x86_ascend_training
  217. @pytest.mark.env_onecard
  218. def test_pynative_hook_outermost_cell_not_change_grad():
  219. input_np = np.ones([2, 2]).astype(np.float32)
  220. ms_net = MsOneInputNet()
  221. ms_net.set_grad()
  222. ms_net.register_backward_hook(ms_net.ms_record_hook)
  223. input_ms = Tensor(input_np)
  224. out_ms = ms_net(input_ms)
  225. grad_net = GradOfAllInputs(ms_net)
  226. grad_net.set_train()
  227. input_ms_grad = grad_net(input_ms, out_ms)
  228. #input grad
  229. input_torch_grad = np.array([[20, 20], [20, 20]])
  230. allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
  231. #hook record grad
  232. torch_net_grad_output = np.array([[10, 10], [10, 10]])
  233. torch_net_grad_input = np.array([[20, 20], [20, 20]])
  234. allclose_nparray(torch_net_grad_output, ms_net.grad_input_list[0].asnumpy(), 0.001, 0.001)
  235. allclose_nparray(torch_net_grad_input, ms_net.grad_output_list[0].asnumpy(), 0.001, 0.001)
  236. @pytest.mark.level0
  237. @pytest.mark.platform_arm_ascend_training
  238. @pytest.mark.platform_x86_ascend_training
  239. @pytest.mark.env_onecard
  240. def test_pynative_hook_all_cell_record_grad():
  241. input_np = np.ones([2, 2]).astype(np.float32)
  242. ms_net = MsOneInputNet()
  243. ms_net.set_grad()
  244. ms_net.mul.register_backward_hook(ms_net.ms_record_hook)
  245. ms_net.add.register_backward_hook(ms_net.ms_record_hook)
  246. ms_net.relu.register_backward_hook(ms_net.ms_record_hook)
  247. input_ms = Tensor(input_np)
  248. out_ms = ms_net(input_ms)
  249. grad_net = GradOfAllInputs(ms_net)
  250. grad_net.set_train()
  251. grad_net(input_ms, out_ms)
  252. torch_net_grad_input0 = np.array([[10, 10], [10, 10]])
  253. torch_net_grad_output0 = np.array([[10, 10], [10, 10]])
  254. torch_net_grad_input1 = np.array([[20, 20], [20, 20]])
  255. torch_net_grad_output1 = np.array([[10, 10], [10, 10]])
  256. allclose_nparray(torch_net_grad_input0, ms_net.grad_output_list[0].asnumpy(), 0.001, 0.001)
  257. allclose_nparray(torch_net_grad_output0, ms_net.grad_input_list[0].asnumpy(), 0.001, 0.001)
  258. allclose_nparray(torch_net_grad_input1, ms_net.grad_output_list[1].asnumpy(), 0.001, 0.001)
  259. allclose_nparray(torch_net_grad_output1, ms_net.grad_input_list[1].asnumpy(), 0.001, 0.001)
  260. torch_net_grad_input3 = np.array([[20, 20], [20, 20]])
  261. torch_net_grad_output2 = np.array([[20, 20], [20, 20]])
  262. allclose_nparray(torch_net_grad_input3, ms_net.grad_output_list[2].asnumpy(), 0.001, 0.001)
  263. allclose_nparray(torch_net_grad_output2, ms_net.grad_input_list[2].asnumpy(), 0.001, 0.001)
  264. @pytest.mark.level0
  265. @pytest.mark.platform_arm_ascend_training
  266. @pytest.mark.platform_x86_ascend_training
  267. @pytest.mark.env_onecard
  268. def test_pynative_hook_mul_change_input_grad():
  269. input_np = np.ones([2, 2]).astype(np.float32)
  270. ms_net = MsOneInputNet()
  271. ms_net.set_grad()
  272. ms_net.mul.register_backward_hook(ms_net.ms_change_grad_double_hook)
  273. input_ms = Tensor(input_np)
  274. out_ms = ms_net(input_ms)
  275. grad_net = GradOfAllInputs(ms_net)
  276. grad_net.set_train()
  277. input_ms_grad = grad_net(input_ms, out_ms)
  278. #input grad
  279. input_torch_grad = np.array([[40, 40], [40, 40]])
  280. allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
  281. @pytest.mark.level0
  282. @pytest.mark.platform_arm_ascend_training
  283. @pytest.mark.platform_x86_ascend_training
  284. @pytest.mark.env_onecard
  285. def test_pynative_hook_mul2_change_input_grad():
  286. input1_np = np.array([2.0, 3.0, 4.0]).astype(np.float32)
  287. input2_np = np.array([2.0, 3.0, 4.0]).astype(np.float32)
  288. ms_net = MsMultiInputNet()
  289. ms_net.set_grad()
  290. ms_net.mul2.register_backward_hook(ms_net.ms_change_grad_double_hook)
  291. input1_ms = Tensor(input1_np)
  292. input2_ms = Tensor(input2_np)
  293. out_ms = ms_net(input1_ms, input2_ms)
  294. grad_net = GradOfAllInputs(ms_net)
  295. grad_net.set_train()
  296. input_ms_grad = grad_net(input1_ms, input2_ms, out_ms)
  297. #input grad
  298. input1_torch_grad = np.array([384, 2916, 12288])
  299. input2_torch_grad = np.array([128, 972, 4096])
  300. allclose_nparray(input1_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
  301. allclose_nparray(input2_torch_grad, input_ms_grad[1].asnumpy(), 0.001, 0.001)
  302. @pytest.mark.level0
  303. @pytest.mark.platform_arm_ascend_training
  304. @pytest.mark.platform_x86_ascend_training
  305. @pytest.mark.env_onecard
  306. def test_pynative_hook_outermost_cell_change_grad():
  307. input_np = np.ones([2, 2]).astype(np.float32)
  308. ms_net = MsNetWithCellinCell()
  309. ms_net.set_grad()
  310. ms_net.register_backward_hook(ms_net.ms_change_grad_double_hook)
  311. input_ms = Tensor(input_np)
  312. out_ms = ms_net(input_ms)
  313. grad_net = GradOfAllInputs(ms_net)
  314. grad_net.set_train()
  315. input_ms_grad = grad_net(input_ms, out_ms)
  316. #input grad
  317. out_torch = np.array([[20, 20], [20, 20]])
  318. input_torch_grad = np.array([[160, 160], [160, 160]])
  319. allclose_nparray(out_torch, out_ms.asnumpy(), 0.001, 0.001)
  320. allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
  321. @pytest.mark.level0
  322. @pytest.mark.platform_arm_ascend_training
  323. @pytest.mark.platform_x86_ascend_training
  324. @pytest.mark.env_onecard
  325. def test_pynative_hook_outermost_cell_record_grad():
  326. input_np = np.ones([2, 2]).astype(np.float32)
  327. ms_net = MsSingleOpNetWithBprop()
  328. ms_net.set_grad()
  329. ms_net.bprop_debug = True
  330. ms_net.register_backward_hook(ms_net.ms_record_hook)
  331. input_ms = Tensor(input_np)
  332. out_ms = ms_net(input_ms)
  333. grad_net = GradOfAllInputs(ms_net)
  334. grad_net.set_train()
  335. input_ms_grad = grad_net(input_ms, out_ms)
  336. if ms_net.grad_output_list or ms_net.grad_input_list:
  337. assert False
  338. #input grad
  339. out_torch = np.array([[1, 1], [1, 1]])
  340. input_torch_grad = np.array([[5, 5], [5, 5]])
  341. allclose_nparray(out_torch, out_ms.asnumpy(), 0.001, 0.001)
  342. allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
  343. @pytest.mark.level0
  344. @pytest.mark.platform_arm_ascend_training
  345. @pytest.mark.platform_x86_ascend_training
  346. @pytest.mark.env_onecard
  347. def test_pynative_hook_bprop_outermost_cell_record_grad():
  348. input_np = np.ones([2, 2]).astype(np.float32)
  349. ms_net = MsNetHasBpropInChild()
  350. ms_net.set_grad()
  351. ms_net.bprop_net.bprop_debug = True
  352. ms_net.register_backward_hook(ms_net.ms_record_hook)
  353. input_ms = Tensor(input_np)
  354. out_ms = ms_net(input_ms)
  355. grad_net = GradOfAllInputs(ms_net)
  356. grad_net.set_train()
  357. input_ms_grad = grad_net(input_ms, out_ms)
  358. if len(ms_net.grad_output_list) != len(ms_net.grad_input_list) or not ms_net.grad_output_list:
  359. assert False
  360. #input grad
  361. out_torch = np.array([[5, 5], [5, 5]])
  362. input_torch_grad = np.array([[25, 25], [25, 25]])
  363. allclose_nparray(out_torch, out_ms.asnumpy(), 0.001, 0.001)
  364. allclose_nparray(input_torch_grad, input_ms_grad[0].asnumpy(), 0.001, 0.001)
  365. #hook record grad
  366. torch_net_grad_output = np.array([[5, 5], [5, 5]])
  367. torch_net_grad_input = np.array([[25, 25], [25, 25]])
  368. allclose_nparray(torch_net_grad_output, ms_net.grad_input_list[0].asnumpy(), 0.001, 0.001)
  369. allclose_nparray(torch_net_grad_input, ms_net.grad_output_list[0].asnumpy(), 0.001, 0.001)
  370. @pytest.mark.level0
  371. @pytest.mark.platform_arm_ascend_training
  372. @pytest.mark.platform_x86_ascend_training
  373. @pytest.mark.env_onecard
  374. def test_pynative_hook_child_cell_record_grad():
  375. input_np = np.ones([2, 2]).astype(np.float32)
  376. ms_net = MsMultiOpNetWithBprop()
  377. ms_net.set_grad()
  378. ms_net.bprop_debug = True
  379. ms_net.relu.register_backward_hook(ms_net.ms_record_hook)
  380. ms_net.mul.register_backward_hook(ms_net.ms_record_hook)
  381. input_ms = Tensor(input_np)
  382. out_ms = ms_net(input_ms)
  383. grad_net = GradOfAllInputs(ms_net)
  384. grad_net.set_train()
  385. grad_net(input_ms, out_ms)
  386. if ms_net.grad_output_list or ms_net.grad_input_list:
  387. assert False