You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_multi_grad.py 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. import numpy as np
  16. from mindspore import context, nn, Tensor, Parameter, ParameterTuple
  17. from mindspore.common import dtype as mstype
  18. from mindspore.ops import composite as C
  19. def setup_module():
  20. context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False)
  21. class _Grad(nn.Cell):
  22. def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
  23. super().__init__()
  24. self.network = network
  25. self.grad = grad
  26. self.sens_param = self.grad.sens_param
  27. self.wrt_params = wrt_params
  28. self.real_inputs_count = real_inputs_count
  29. if self.wrt_params:
  30. self.params = ParameterTuple(self.network.trainable_params())
  31. def construct(self, *inputs):
  32. if self.wrt_params:
  33. if self.real_inputs_count is None or self.sens_param is False:
  34. return self.grad(self.network, self.params)(*inputs)
  35. real_inputs = inputs[:self.real_inputs_count]
  36. sense_param_inputs = inputs[self.real_inputs_count:]
  37. return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
  38. if self.real_inputs_count is None or self.sens_param is False:
  39. return self.grad(self.network)(*inputs)
  40. real_inputs = inputs[:self.real_inputs_count]
  41. sense_param_inputs = inputs[self.real_inputs_count:]
  42. return self.grad(self.network)(*real_inputs, sense_param_inputs)
  43. class GradOfFirstInput(_Grad):
  44. """
  45. get grad of first input
  46. """
  47. def __init__(self, network, sens_param=True, real_inputs_count=None):
  48. super().__init__(grad=C.GradOperation(sens_param=sens_param),
  49. network=network, real_inputs_count=real_inputs_count)
  50. class GradOfAllInputs(_Grad):
  51. """
  52. get grad of first input
  53. """
  54. def __init__(self, network, sens_param=True, real_inputs_count=None):
  55. super().__init__(grad=C.GradOperation(get_all=True, sens_param=sens_param),
  56. network=network, real_inputs_count=real_inputs_count)
  57. def test_multi_grad():
  58. class ForwardNetMul(nn.Cell):
  59. def __init__(self):
  60. super().__init__()
  61. def construct(self, x, y):
  62. a = x * x
  63. b = y * y
  64. return a * b
  65. class ForwardNetAdd(nn.Cell):
  66. def __init__(self):
  67. super().__init__()
  68. def construct(self, x, y):
  69. a = x + x + x
  70. b = y + y
  71. return a * b
  72. mulnet = ForwardNetMul()
  73. addnet = ForwardNetAdd()
  74. x = Tensor(np.ones([32]), dtype=mstype.float32)
  75. y = Tensor(np.ones([32])*2, dtype=mstype.float32)
  76. sens = Tensor(np.ones([32]), dtype=mstype.float32)
  77. mulnet.set_grad()
  78. addnet.set_grad()
  79. out1 = mulnet(x, y)
  80. out2 = addnet(x, y)
  81. grad_mul = GradOfAllInputs(mulnet)
  82. grad_add = GradOfAllInputs(addnet)
  83. grad_mul(x, y, sens)
  84. grad_add(x, y, sens)
  85. def test_multi_same_grad():
  86. class ForwardNetMul(nn.Cell):
  87. def __init__(self):
  88. super().__init__()
  89. def construct(self, x, y):
  90. a = x * x
  91. b = y * y
  92. return a * b
  93. class ForwardNetAdd(nn.Cell):
  94. def __init__(self):
  95. super().__init__()
  96. def construct(self, x, y):
  97. a = x*3
  98. b = y*2
  99. return a + b
  100. mulnet = ForwardNetMul()
  101. addnet = ForwardNetAdd()
  102. x = Tensor(np.ones([32]), dtype=mstype.float32)
  103. y = Tensor(np.ones([32]), dtype=mstype.float32)
  104. sens = Tensor(np.ones([32]), dtype=mstype.float32)
  105. mulnet.set_grad()
  106. addnet.set_grad()
  107. out1 = mulnet(x, y)
  108. out2 = addnet(x, y)
  109. grad_mul = GradOfAllInputs(mulnet)
  110. grad_add = GradOfFirstInput(mulnet)
  111. grad_mul(x, y, sens)
  112. grad_add(x, y, sens)
  113. def test_net_inner_grad():
  114. class ForwardNetMul(nn.Cell):
  115. def __init__(self):
  116. super().__init__()
  117. def construct(self, x, y):
  118. a = x * x
  119. b = y * y
  120. return a * b
  121. class ForwardNetAdd(nn.Cell):
  122. def __init__(self, net):
  123. super().__init__()
  124. self.net = net
  125. def construct(self, x, y):
  126. a = x + x
  127. b = y + y
  128. res = self.net(a, b)
  129. return res
  130. mulnet = ForwardNetMul()
  131. addnet = ForwardNetAdd(mulnet)
  132. x = Tensor(np.ones([32]), dtype=mstype.float32)
  133. y = Tensor(np.ones([32]), dtype=mstype.float32)
  134. sens = Tensor(np.ones([32]), dtype=mstype.float32)
  135. mulnet.set_grad()
  136. addnet.set_grad()
  137. out1 = mulnet(x, y)
  138. out2 = addnet(x, y)
  139. grad_mul = GradOfAllInputs(addnet)
  140. grad_add = GradOfAllInputs(mulnet)
  141. grad_mul(x, y, sens)
  142. grad_add(x, y, sens)
  143. def test_net_inner_first_run_grad():
  144. class ForwardNetMul(nn.Cell):
  145. def __init__(self):
  146. super().__init__()
  147. self.z1 = Parameter(Tensor(np.ones([32])*2, dtype=mstype.float32), name='z1')
  148. def construct(self, x, y):
  149. a = x * self.z1
  150. b = y * y
  151. return a * b
  152. class ForwardNetAdd(nn.Cell):
  153. def __init__(self, net):
  154. super().__init__()
  155. self.net = net
  156. self.z2 = Parameter(Tensor(np.ones([32]), dtype=mstype.float32), name='z2')
  157. self.z3 = Parameter(Tensor(np.ones([32]), dtype=mstype.float32), name='z2')
  158. def construct(self, x, y):
  159. a = x + x*self.z3
  160. b = y + y*self.z2
  161. res = self.net(a, b)
  162. return res
  163. mulnet = ForwardNetMul()
  164. addnet = ForwardNetAdd(mulnet)
  165. x = Tensor(np.ones([32]), dtype=mstype.float32)
  166. y = Tensor(np.ones([32]), dtype=mstype.float32)
  167. sens = Tensor(np.ones([32]), dtype=mstype.float32)
  168. mulnet.set_grad()
  169. addnet.set_grad()
  170. out1 = mulnet(x, y)
  171. out2 = addnet(x, y)
  172. grad_mul = GradOfAllInputs(addnet)
  173. grad_add = GradOfFirstInput(mulnet)
  174. grad_mul(x, y, sens)
  175. grad_add(x, y, sens)