You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they may include dashes ('-') and can be up to 35 characters long.

test_hook_function.py 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import pytest
  16. import numpy as np
  17. import mindspore.nn as nn
  18. import mindspore.ops.operations as P
  19. from mindspore.ops import composite as C
  20. from mindspore import context, Tensor
  21. from mindspore.common.api import ms_function
# Gradient operator used by the tests below; get_all=True makes it return the
# gradients w.r.t. every network input as a tuple.
grad_all = C.GradOperation(get_all=True)
  23. def var_hook_function(grad_out):
  24. print("grad:", grad_out)
  25. class GraphVarHook(nn.Cell):
  26. def __init__(self):
  27. super(GraphVarHook, self).__init__()
  28. self.relu = nn.ReLU()
  29. self.hook = P.HookBackward(var_hook_function)
  30. def construct(self, x):
  31. x = x + x
  32. x = x * x
  33. x = self.hook(x)
  34. x = self.relu(x)
  35. return x
  36. class MsFuncVarHook(nn.Cell):
  37. def __init__(self):
  38. super(MsFuncVarHook, self).__init__()
  39. self.relu = nn.ReLU()
  40. self.hook = P.HookBackward(var_hook_function)
  41. @ms_function
  42. def construct(self, x):
  43. x = x + x
  44. x = x * x
  45. x = self.hook(x)
  46. x = self.relu(x)
  47. return x
  48. @pytest.mark.level0
  49. @pytest.mark.platform_x86_cpu
  50. @pytest.mark.platform_arm_ascend_training
  51. @pytest.mark.platform_x86_ascend_training
  52. @pytest.mark.platform_x86_gpu_training
  53. @pytest.mark.env_onecard
  54. def test_var_hook_forward():
  55. input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
  56. context.set_context(mode=context.PYNATIVE_MODE)
  57. net1 = MsFuncVarHook()
  58. out1 = net1(input_x)
  59. context.set_context(mode=context.GRAPH_MODE)
  60. net2 = GraphVarHook()
  61. out2 = net2(input_x)
  62. assert np.allclose(out1.asnumpy(), out2.asnumpy(), 0.00001, 0.00001)
  63. @pytest.mark.level0
  64. @pytest.mark.platform_x86_cpu
  65. @pytest.mark.platform_arm_ascend_training
  66. @pytest.mark.platform_x86_ascend_training
  67. @pytest.mark.platform_x86_gpu_training
  68. @pytest.mark.env_onecard
  69. def test_var_hook_grad():
  70. input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
  71. context.set_context(mode=context.PYNATIVE_MODE)
  72. net1 = MsFuncVarHook()
  73. grad_out1 = grad_all(net1)(input_x)
  74. context.set_context(mode=context.GRAPH_MODE)
  75. net2 = GraphVarHook()
  76. grad_out2 = grad_all(net2)(input_x)
  77. assert np.allclose(grad_out1[0].asnumpy(), grad_out2[0].asnumpy(), 0.00001, 0.00001)
  78. def cell_hook_function(cell_id, grad_input, grad_output):
  79. print("cell id:", cell_id)
  80. print("grad input:", grad_input)
  81. print("grad output:", grad_output)
  82. class GraphCellHook(nn.Cell):
  83. def __init__(self):
  84. super(GraphCellHook, self).__init__()
  85. self.relu = nn.ReLU()
  86. self.relu.register_backward_hook(cell_hook_function)
  87. def construct(self, x):
  88. x = x + x
  89. x = x * x
  90. x = self.relu(x)
  91. return x
  92. class MsFuncCellHook(nn.Cell):
  93. def __init__(self):
  94. super(MsFuncCellHook, self).__init__()
  95. self.relu = nn.ReLU()
  96. self.relu.register_backward_hook(cell_hook_function)
  97. @ms_function
  98. def construct(self, x):
  99. x = x + x
  100. x = x * x
  101. x = self.relu(x)
  102. return x
  103. @pytest.mark.level0
  104. @pytest.mark.platform_x86_cpu
  105. @pytest.mark.platform_arm_ascend_training
  106. @pytest.mark.platform_x86_ascend_training
  107. @pytest.mark.platform_x86_gpu_training
  108. @pytest.mark.env_onecard
  109. def test_cell_hook_forward():
  110. input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
  111. context.set_context(mode=context.PYNATIVE_MODE)
  112. net1 = MsFuncCellHook()
  113. out1 = net1(input_x)
  114. context.set_context(mode=context.GRAPH_MODE)
  115. net2 = GraphCellHook()
  116. out2 = net2(input_x)
  117. assert np.allclose(out1.asnumpy(), out2.asnumpy(), 0.00001, 0.00001)
  118. @pytest.mark.level0
  119. @pytest.mark.platform_x86_cpu
  120. @pytest.mark.platform_arm_ascend_training
  121. @pytest.mark.platform_x86_ascend_training
  122. @pytest.mark.platform_x86_gpu_training
  123. @pytest.mark.env_onecard
  124. def test_cell_hook_grad():
  125. input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
  126. context.set_context(mode=context.PYNATIVE_MODE)
  127. net1 = MsFuncCellHook()
  128. grad_out1 = grad_all(net1)(input_x)
  129. context.set_context(mode=context.GRAPH_MODE)
  130. net2 = GraphCellHook()
  131. grad_out2 = grad_all(net2)(input_x)
  132. assert np.allclose(grad_out1[0].asnumpy(), grad_out2[0].asnumpy(), 0.00001, 0.00001)