You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_var_grad.py 7.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. import numpy as np
  16. from mindspore import context
  17. from mindspore import Tensor, Parameter
  18. from mindspore.nn import Cell
  19. from mindspore.ops import operations as P
  20. import mindspore.ops.composite as C
  21. from mindspore.common.api import _executor
  22. from mindspore.common.parameter import ParameterTuple
  23. from mindspore.common import dtype as mstype
  24. context.set_context(mode=context.GRAPH_MODE)
  25. def test_net_vargs_expand():
  26. class AddNet(Cell):
  27. def __init__(self):
  28. super(AddNet, self).__init__()
  29. self.w = Parameter(Tensor(np.ones((3, 4, 5), np.float32)), "w2", requires_grad=True)
  30. def construct(self, x, y):
  31. return x + y
  32. x = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
  33. y = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
  34. sens = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
  35. net = AddNet()
  36. out = C.grad_all_with_sens(net, net.trainable_params())(x, y, sens)
  37. class VarNet(Cell):
  38. def __init__(self, net):
  39. super(VarNet, self).__init__()
  40. self.b = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b", requires_grad=True)
  41. self.w = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "w", requires_grad=True)
  42. self.net = net
  43. def construct(self, *args):
  44. return self.net(*args)*self.w + self.b
  45. class SecondNet(Cell):
  46. def __init__(self):
  47. super(SecondNet, self).__init__()
  48. self.b2 = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b2", requires_grad=True)
  49. def construct(self, *args):
  50. res = args[0] + args[1]
  51. return res + self.b2
  52. def test_all_var_args_grad_with_sens():
  53. """"test grad_by_list_with_sens with all var args input"""
  54. class GradNet(Cell):
  55. def __init__(self, net):
  56. super(GradNet, self).__init__()
  57. self.weights = ParameterTuple(net.trainable_params())
  58. self.net = net
  59. def construct(self, *inputs):
  60. return C.grad_by_list_with_sens(self.net, self.weights)(*inputs)
  61. x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
  62. y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
  63. sens = Tensor(1.0, dtype=mstype.float32)
  64. net = VarNet(SecondNet())
  65. grad_net = GradNet(net)
  66. out = grad_net(x, y, sens)
  67. def test_grad_list_var_args():
  68. class GradNet(Cell):
  69. def __init__(self, net):
  70. super(GradNet, self).__init__()
  71. self.weights = ParameterTuple(net.trainable_params())
  72. self.net = net
  73. def construct(self, *inputs):
  74. return C.grad_by_list(self.net, self.weights)(*inputs)
  75. x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
  76. y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
  77. net = VarNet(SecondNet())
  78. grad_net = GradNet(net)
  79. out = grad_net(x, y)
  80. def test_grad_all_var_args():
  81. class GradNet(Cell):
  82. def __init__(self, net):
  83. super(GradNet, self).__init__()
  84. self.weights = ParameterTuple(net.trainable_params())
  85. self.net = net
  86. def construct(self, *inputs):
  87. return C.grad_all(self.net)(*inputs)
  88. x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
  89. y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
  90. net = VarNet(SecondNet())
  91. grad_net = GradNet(net)
  92. out = grad_net(x, y)
  93. def test_grad_all_var_args_with_sens():
  94. class GradNet(Cell):
  95. def __init__(self, net):
  96. super(GradNet, self).__init__()
  97. self.weights = ParameterTuple(net.trainable_params())
  98. self.net = net
  99. def construct(self, *inputs):
  100. return C.grad_all_with_sens(self.net)(*inputs)
  101. x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
  102. y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
  103. sens = Tensor(1.0, dtype=mstype.float32)
  104. net = VarNet(SecondNet())
  105. grad_net = GradNet(net)
  106. out = grad_net(x, y, sens)
  107. def test_grad_var_args_with_sens():
  108. class GradNet(Cell):
  109. def __init__(self, net):
  110. super(GradNet, self).__init__()
  111. self.weights = ParameterTuple(net.trainable_params())
  112. self.net = net
  113. def construct(self, *inputs):
  114. return C.grad_with_sens(self.net)(*inputs)
  115. x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
  116. y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
  117. sens = Tensor(1.0, dtype=mstype.float32)
  118. net = VarNet(SecondNet())
  119. grad_net = GradNet(net)
  120. out = grad_net(x, y, sens)
  121. def test_var_args_grad():
  122. class VarNet(Cell):
  123. def __init__(self, net):
  124. super(VarNet, self).__init__()
  125. self.b = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b", requires_grad=True)
  126. self.net = net
  127. def construct(self, *args):
  128. return self.net(*args) + self.b
  129. class SecondNet(Cell):
  130. def __init__(self):
  131. super(SecondNet, self).__init__()
  132. self.b2 = Parameter(Tensor(np.ones([3, 4, 5]), dtype=mstype.float32), "b2", requires_grad=True)
  133. def construct(self, *args):
  134. res = args[0] + args[1]
  135. return res + self.b2
  136. class GradNet(Cell):
  137. def __init__(self, net):
  138. super(GradNet, self).__init__()
  139. self.net = net
  140. self.weights = ParameterTuple(net.trainable_params())
  141. def construct(self, x, y, sens):
  142. return C.grad_by_list_with_sens(self.net, self.weights)(x, y, sens)
  143. x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
  144. y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
  145. sens = Tensor(1.0, dtype=mstype.float32)
  146. net = VarNet(SecondNet())
  147. grad_net = GradNet(net)
  148. out = grad_net(x, y, sens)
  149. def test_var_args_positional():
  150. """"test grad_all with var args in inner graph"""
  151. class VarNet(Cell):
  152. def __init__(self, net):
  153. super(VarNet, self).__init__()
  154. self.net = net
  155. def construct(self, x, y):
  156. return self.net(x, y)*x
  157. class SecondNet(Cell):
  158. def __init__(self):
  159. super(SecondNet, self).__init__()
  160. def construct(self, *args):
  161. return args[0] + args[1]
  162. class GradNet(Cell):
  163. def __init__(self, net):
  164. super(GradNet, self).__init__()
  165. self.net = net
  166. self.weights = ParameterTuple(net.trainable_params())
  167. def construct(self, x, y):
  168. return C.grad_all(self.net)(x, y)
  169. x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
  170. y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
  171. net = VarNet(SecondNet())
  172. grad_net = GradNet(net)
  173. out = grad_net(x, y)