You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_multi_grad.py 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import pytest
  16. import numpy as np
  17. from mindspore import context, nn, Tensor, Parameter, ParameterTuple
  18. from mindspore.common import dtype as mstype
  19. from mindspore.ops import composite as C
  20. @pytest.fixture(scope="module", autouse=True)
  21. def setup_teardown():
  22. context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
  23. yield
  24. context.set_context(mode=context.GRAPH_MODE)
  25. class _Grad(nn.Cell):
  26. def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
  27. super().__init__()
  28. self.network = network
  29. self.grad = grad
  30. self.sens_param = self.grad.sens_param
  31. self.wrt_params = wrt_params
  32. self.real_inputs_count = real_inputs_count
  33. if self.wrt_params:
  34. self.params = ParameterTuple(self.network.trainable_params())
  35. def construct(self, *inputs):
  36. if self.wrt_params:
  37. if self.real_inputs_count is None or self.sens_param is False:
  38. return self.grad(self.network, self.params)(*inputs)
  39. real_inputs = inputs[:self.real_inputs_count]
  40. sense_param_inputs = inputs[self.real_inputs_count:]
  41. return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
  42. if self.real_inputs_count is None or self.sens_param is False:
  43. return self.grad(self.network)(*inputs)
  44. real_inputs = inputs[:self.real_inputs_count]
  45. sense_param_inputs = inputs[self.real_inputs_count:]
  46. return self.grad(self.network)(*real_inputs, sense_param_inputs)
  47. class GradOfFirstInput(_Grad):
  48. """
  49. get grad of first input
  50. """
  51. def __init__(self, network, sens_param=True, real_inputs_count=None):
  52. super().__init__(grad=C.GradOperation(sens_param=sens_param),
  53. network=network, real_inputs_count=real_inputs_count)
  54. class GradOfAllInputs(_Grad):
  55. """
  56. get grad of first input
  57. """
  58. def __init__(self, network, sens_param=True, real_inputs_count=None):
  59. super().__init__(grad=C.GradOperation(get_all=True, sens_param=sens_param),
  60. network=network, real_inputs_count=real_inputs_count)
  61. def test_multi_grad():
  62. class ForwardNetMul(nn.Cell):
  63. def __init__(self):
  64. super().__init__()
  65. def construct(self, x, y):
  66. a = x * x
  67. b = y * y
  68. return a * b
  69. class ForwardNetAdd(nn.Cell):
  70. def __init__(self):
  71. super().__init__()
  72. def construct(self, x, y):
  73. a = x + x + x
  74. b = y + y
  75. return a * b
  76. mulnet = ForwardNetMul()
  77. addnet = ForwardNetAdd()
  78. x = Tensor(np.ones([32]), dtype=mstype.float32)
  79. y = Tensor(np.ones([32])*2, dtype=mstype.float32)
  80. sens = Tensor(np.ones([32]), dtype=mstype.float32)
  81. mulnet.set_grad()
  82. addnet.set_grad()
  83. out1 = mulnet(x, y)
  84. out2 = addnet(x, y)
  85. grad_mul = GradOfAllInputs(mulnet)
  86. grad_add = GradOfAllInputs(addnet)
  87. grad_mul(x, y, sens)
  88. grad_add(x, y, sens)
  89. def test_multi_same_grad():
  90. class ForwardNetMul(nn.Cell):
  91. def __init__(self):
  92. super().__init__()
  93. def construct(self, x, y):
  94. a = x * x
  95. b = y * y
  96. return a * b
  97. class ForwardNetAdd(nn.Cell):
  98. def __init__(self):
  99. super().__init__()
  100. def construct(self, x, y):
  101. a = x*3
  102. b = y*2
  103. return a + b
  104. mulnet = ForwardNetMul()
  105. addnet = ForwardNetAdd()
  106. x = Tensor(np.ones([32]), dtype=mstype.float32)
  107. y = Tensor(np.ones([32]), dtype=mstype.float32)
  108. sens = Tensor(np.ones([32]), dtype=mstype.float32)
  109. mulnet.set_grad()
  110. addnet.set_grad()
  111. out1 = mulnet(x, y)
  112. out2 = addnet(x, y)
  113. grad_mul = GradOfAllInputs(mulnet)
  114. grad_add = GradOfFirstInput(mulnet)
  115. grad_mul(x, y, sens)
  116. grad_add(x, y, sens)
  117. def test_net_inner_grad():
  118. class ForwardNetMul(nn.Cell):
  119. def __init__(self):
  120. super().__init__()
  121. def construct(self, x, y):
  122. a = x * x
  123. b = y * y
  124. return a * b
  125. class ForwardNetAdd(nn.Cell):
  126. def __init__(self, net):
  127. super().__init__()
  128. self.net = net
  129. def construct(self, x, y):
  130. a = x + x
  131. b = y + y
  132. res = self.net(a, b)
  133. return res
  134. mulnet = ForwardNetMul()
  135. addnet = ForwardNetAdd(mulnet)
  136. x = Tensor(np.ones([32]), dtype=mstype.float32)
  137. y = Tensor(np.ones([32]), dtype=mstype.float32)
  138. sens = Tensor(np.ones([32]), dtype=mstype.float32)
  139. mulnet.set_grad()
  140. addnet.set_grad()
  141. out1 = mulnet(x, y)
  142. out2 = addnet(x, y)
  143. grad_mul = GradOfAllInputs(addnet)
  144. grad_add = GradOfAllInputs(mulnet)
  145. grad_mul(x, y, sens)
  146. grad_add(x, y, sens)
  147. def test_net_inner_first_run_grad():
  148. class ForwardNetMul(nn.Cell):
  149. def __init__(self):
  150. super().__init__()
  151. self.z1 = Parameter(Tensor(np.ones([32])*2, dtype=mstype.float32), name='z1')
  152. def construct(self, x, y):
  153. a = x * self.z1
  154. b = y * y
  155. return a * b
  156. class ForwardNetAdd(nn.Cell):
  157. def __init__(self, net):
  158. super().__init__()
  159. self.net = net
  160. self.z2 = Parameter(Tensor(np.ones([32]), dtype=mstype.float32), name='z2')
  161. self.z3 = Parameter(Tensor(np.ones([32]), dtype=mstype.float32), name='z2')
  162. def construct(self, x, y):
  163. a = x + x*self.z3
  164. b = y + y*self.z2
  165. res = self.net(a, b)
  166. return res
  167. mulnet = ForwardNetMul()
  168. addnet = ForwardNetAdd(mulnet)
  169. x = Tensor(np.ones([32]), dtype=mstype.float32)
  170. y = Tensor(np.ones([32]), dtype=mstype.float32)
  171. sens = Tensor(np.ones([32]), dtype=mstype.float32)
  172. mulnet.set_grad()
  173. addnet.set_grad()
  174. out1 = mulnet(x, y)
  175. out2 = addnet(x, y)
  176. grad_mul = GradOfAllInputs(addnet)
  177. grad_add = GradOfFirstInput(mulnet)
  178. grad_mul(x, y, sens)
  179. grad_add(x, y, sens)