You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cell_bprop.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_cell_bprop """
  16. import numpy as np
  17. import mindspore.nn as nn
  18. from mindspore.ops import composite as C
  19. from mindspore.ops import operations as P
  20. from mindspore import Parameter
  21. from mindspore.common.tensor import Tensor
  22. import mindspore.common.dtype as mstype
  23. from mindspore.common.initializer import initializer
  24. from mindspore import context
  25. from ....mindspore_test_framework.utils.bprop_util import bprop
  26. import pytest
  27. def setup_module(module):
  28. context.set_context(mode=context.PYNATIVE_MODE)
  29. class MulAdd(nn.Cell):
  30. def __init__(self):
  31. super(MulAdd, self).__init__()
  32. def construct(self, x, y):
  33. return 2 * x + y
  34. def bprop(self, x, y, out, dout):
  35. # In this test case, The user defined bprop is wrong defined purposely to distinguish from ad result
  36. return 2 * dout, 2 * y
  37. def test_grad_mul_add():
  38. mul_add = MulAdd()
  39. assert C.grad_all(mul_add)(1, 2) == (2, 4)
  40. class InlineMulADD(nn.Cell):
  41. def __init__(self):
  42. super(InlineMulADD, self).__init__()
  43. self.mul_add = MulAdd()
  44. self.param = 2
  45. def construct(self, x, y):
  46. return self.mul_add(x, y) + x + self.param * y
  47. def test_grad_inline_mul_add():
  48. inline_mul_add = InlineMulADD()
  49. assert C.grad_all(inline_mul_add)(1, 2) == (3, 6)
  50. class WithParameter(nn.Cell):
  51. def __init__(self):
  52. super(WithParameter, self).__init__()
  53. self.param1 = Parameter(1, 'param1')
  54. self.param2 = Parameter(2, 'param2')
  55. def construct(self, x, y):
  56. return self.param1 * self.param2 * x + y
  57. def bprop(self, x, y, out, dout):
  58. # In this test case, The user defined bprop is wrong defined purposely to distinguish from ad result
  59. return self.param1 * self.param2 * dout, 2 * y
  60. def test_with_param():
  61. with_param = WithParameter()
  62. with pytest.raises(RuntimeError):
  63. C.grad_all(with_param)(1, 2)
  64. class WithNoBprop(nn.Cell):
  65. def __init__(self):
  66. super(WithNoBprop, self).__init__()
  67. def construct(self, x, y):
  68. return 2 * x + y
  69. def test_with_no_bprop():
  70. with_no_bprop = WithNoBprop()
  71. C.grad_all(with_no_bprop)(1, 2) == (2, 1)
def test_grad_in_bprop_1():
    """Calling grad_all on an inner cell from inside an outer cell's bprop.

    The innermost cell has no custom bprop, so the ``grad_all`` call inside
    ``GradInBprop_2.bprop`` uses automatic differentiation of ``relu(x)``.
    """
    class GradInBprop_1(nn.Cell):
        def __init__(self):
            super(GradInBprop_1, self).__init__()
            self.relu = P.ReLU()

        def construct(self, x, y):
            # y is intentionally unused, so its gradient should be zero.
            return self.relu(x)

    class GradInBprop_2(nn.Cell):
        def __init__(self):
            super(GradInBprop_2, self).__init__()
            self.f = GradInBprop_1()

        def construct(self, x, y):
            return self.f(x, y), C.grad_all(self.f)(x, y)

        def bprop(self, x, y, out, dout):
            # Take gradients of the inner cell again, inside the custom bprop.
            grads = C.grad_all(self.f)(x, y)
            return out[1][0], grads[1]

    class GradInBprop_3(nn.Cell):
        def __init__(self):
            super(GradInBprop_3, self).__init__()
            self.f = GradInBprop_2()

        def construct(self, x, y):
            return self.f(x, y)

    grad_in_bprop = GradInBprop_3()
    grads = C.grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)),
                                      Tensor(np.ones([2, 2]).astype(np.float32)))
    # Expected: d/dx relu(ones) is ones; y never contributes, so its grad is zero.
    assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.zeros([2, 2]).astype(np.float32)).all()
def test_grad_in_bprop_2():
    """Nested cells where the innermost cell also defines a custom bprop.

    ``GradInBprop_1.bprop`` returns ``(x * y, y + x)``, so the gradients
    computed inside ``GradInBprop_2.bprop`` come from that user-defined
    rule rather than from automatic differentiation.
    """
    class GradInBprop_1(nn.Cell):
        def __init__(self):
            super(GradInBprop_1, self).__init__()
            self.relu = P.ReLU()

        def construct(self, x, y):
            return self.relu(x)

        def bprop(self, x, y, out, dout):
            # Custom rule: gradient w.r.t. x is x * y, w.r.t. y is y + x.
            return x * y, y + x

    class GradInBprop_2(nn.Cell):
        def __init__(self):
            super(GradInBprop_2, self).__init__()
            self.f = GradInBprop_1()

        def construct(self, x, y):
            return self.f(x, y), C.grad_all(self.f)(x, y)

        def bprop(self, x, y, out, dout):
            # grad_all here picks up GradInBprop_1's custom bprop.
            grads = C.grad_all(self.f)(x, y)
            return out[1][0], grads[1]

    class GradInBprop_3(nn.Cell):
        def __init__(self):
            super(GradInBprop_3, self).__init__()
            self.f = GradInBprop_2()

        def construct(self, x, y):
            return self.f(x, y)

    grad_in_bprop = GradInBprop_3()
    grads = C.grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)),
                                      Tensor(np.ones([2, 2]).astype(np.float32)))
    # With all-ones inputs: x * y == 1 and y + x == 2 elementwise.
    assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.array([[2, 2], [2, 2]]).astype(np.float32)).all()
def test_grad_in_bprop_3():
    """Custom bprops on both the middle and the outermost cell.

    The outermost cell's bprop overrides whatever the inner cells would
    produce, mixing inputs, forward output and incoming sensitivities.
    """
    class GradInBprop_1(nn.Cell):
        def __init__(self):
            super(GradInBprop_1, self).__init__()
            self.relu = P.ReLU()

        def construct(self, x, y):
            # y is intentionally unused by the forward computation.
            return self.relu(x)

    class GradInBprop_2(nn.Cell):
        def __init__(self):
            super(GradInBprop_2, self).__init__()
            self.f = GradInBprop_1()

        def construct(self, x, y):
            return self.f(x, y), C.grad_all(self.f)(x, y)

        def bprop(self, x, y, out, dout):
            # Re-derive the inner cell's gradients within the custom bprop.
            grads = C.grad_all(self.f)(x, y)
            return out[1][0], grads[1]

    class GradInBprop_3(nn.Cell):
        def __init__(self):
            super(GradInBprop_3, self).__init__()
            self.f = GradInBprop_2()

        def construct(self, x, y):
            return self.f(x, y)

        def bprop(self, x, y, out, dout):
            # Outermost rule mixes inputs, forward output and dout:
            # x + y + y + out[0] and x + x + y + y + dout[0].
            return x + y + y + out[0], x + x + y + y + dout[0]

    grad_in_bprop = GradInBprop_3()
    grads = C.grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)),
                                      Tensor(np.ones([2, 2]).astype(np.float32)))
    # With all-ones inputs the outer bprop yields 4s and 5s elementwise.
    assert (grads[0].asnumpy() == np.array([[4, 4], [4, 4]]).astype(np.float32)).all()
    assert (grads[1].asnumpy() == np.array([[5, 5], [5, 5]]).astype(np.float32)).all()
  157. class OneInputBprop(nn.Cell):
  158. def __init__(self):
  159. super().__init__()
  160. self.op = P.ReLU()
  161. def construct(self, x):
  162. return self.op(x)
  163. def bprop(self, x, out, dout):
  164. return 5 * x,
  165. def test_grad_one_input_bprop():
  166. net = OneInputBprop()
  167. input = Tensor(np.ones([2, 2]).astype(np.float32))
  168. grad = C.grad_all(net)(input)
  169. assert (grad[0].asnumpy() == np.array([5, 5]).astype(np.float32)).all()
  170. class TwoInput(nn.Cell):
  171. def __init__(self):
  172. super().__init__()
  173. def construct(self, x, y):
  174. return x * y
  175. class InlineBpropTwoInput(nn.Cell):
  176. def __init__(self):
  177. super().__init__()
  178. self.f = TwoInput()
  179. def construct(self, x, y):
  180. return self.f(x, y), C.grad_all(self.f)(x, y)
  181. def bprop(self, x, y, out, dout):
  182. grads = C.grad_all(self.f)(x, y)
  183. return grads[0] * 2, grads[1] * 2
  184. def test_grad_inline_bprop_two_input():
  185. net = InlineBpropTwoInput()
  186. input1 = Tensor(np.ones([2, 2]).astype(np.float32))
  187. input2 = Tensor(np.ones([2, 2]).astype(np.float32))
  188. grads = C.grad_all(net)(input1, input2)
  189. assert (grads[0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
  190. assert (grads[1].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
  191. assert (len(grads) == 2)
  192. class TwoInputBprop(nn.Cell):
  193. def __init__(self):
  194. super().__init__()
  195. self.op = P.Mul()
  196. def construct(self, x, y):
  197. return self.op(x, y)
  198. def bprop(self, x, y, out, dout):
  199. return 5 * x, 8 * y
class TwoInput(nn.Cell):
    # NOTE(review): this redefines the earlier ``TwoInput`` class in this
    # module. Both compute x * y (this one via P.Mul), so behavior looks
    # equivalent, but the name collision is fragile — one of the two
    # should be renamed in a coordinated change.
    def __init__(self):
        super().__init__()
        self.op = P.Mul()

    def construct(self, x, y):
        # Element-wise multiply via the Mul primitive; no custom bprop,
        # so gradients come from automatic differentiation.
        return self.op(x, y)
  206. class TwoInputWithParameter(nn.Cell):
  207. def __init__(self):
  208. super().__init__()
  209. self.op = P.Mul()
  210. self.inputdata = Parameter(initializer(1, (2,2), mstype.float32),name="global_step")
  211. def construct(self, x, y):
  212. x = self.inputdata + x
  213. return self.op(x, y)
  214. class TwoInputWithOnlyInitParameterBprop(nn.Cell):
  215. def __init__(self):
  216. super().__init__()
  217. self.op = P.Mul()
  218. self.inputdata = Parameter(initializer(1, (2,2), mstype.float32),name="global_step")
  219. def construct(self, x, y):
  220. return self.op(x, y)
  221. def bprop(self, x, y, out, dout):
  222. return 5*x, 8*y
  223. class InlineMutilTwoInputParameterCell(nn.Cell):
  224. def __init__(self):
  225. super().__init__()
  226. self.f1 = TwoInputBprop()
  227. self.f2 = TwoInput()
  228. self.f3 = TwoInputWithParameter()
  229. self.f4 = TwoInputWithOnlyInitParameterBprop()
  230. def construct(self, x, y):
  231. output = self.f1(x,y)+self.f2(x,y)+self.f3(x,y)+self.f4(x,y)
  232. return output
  233. def test_grad_inline_bprop_multi_input():
  234. net = InlineMutilTwoInputParameterCell()
  235. input1 = Tensor(np.ones([2, 2]).astype(np.float32))
  236. input2 = Tensor(np.ones([2, 2]).astype(np.float32))
  237. grads = C.grad_all(net)(input1, input2)
  238. assert (grads[0].asnumpy() == np.array([[12, 12], [12, 12]]).astype(np.float32)).all()
  239. assert (grads[1].asnumpy() == np.array([[19, 19], [19, 19]]).astype(np.float32)).all()
  240. assert (len(grads) == 2)
  241. class MulAddWithParam(nn.Cell):
  242. def __init__(self):
  243. super(MulAddWithParam, self).__init__()
  244. self.mul_add = MulAdd()
  245. self.param = Parameter(Tensor(np.array([[3, 2]], np.float32)), 'param')
  246. def construct(self, x):
  247. return self.mul_add(self.param, x)
  248. def test_refkey_bprop():
  249. net = MulAddWithParam()
  250. input_data = Tensor(np.array([2, 2], np.float32))
  251. grads = bprop(net, input_data,
  252. grads_wrt_outputs=(Tensor(np.ones([1, 2]).astype(np.float32))),
  253. wrt=['params', 'inputs'],
  254. params=net.trainable_params())
  255. assert (grads[0][0].asnumpy() == np.array([4, 4]).astype(np.float32)).all()
  256. assert (grads[1][0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
  257. class MulAddWithWrongOutputNum(nn.Cell):
  258. def __init__(self):
  259. super(MulAddWithWrongOutputNum, self).__init__()
  260. def construct(self, x, y):
  261. return 2 * x + y
  262. def bprop(self, x, y, out, dout):
  263. return 2 * dout, 2 * y, out
  264. def test_grad_mul_add_with_wrong_output_num():
  265. mul_add = MulAddWithWrongOutputNum()
  266. C.grad_all(mul_add)(1, 2)