You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_stop_gradient.py 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_stop_gradient """
  16. import numpy as np
  17. import pytest
  18. import mindspore as ms
  19. import mindspore.common.dtype as mstype
  20. import mindspore.nn as nn
  21. from mindspore import Parameter, ParameterTuple
  22. from mindspore import Tensor
  23. from mindspore import context
  24. from mindspore.common.api import ms_function
  25. from mindspore.ops import composite as C
  26. from mindspore.ops import operations as P
  27. from mindspore.ops.functional import stop_gradient
  28. from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
  29. from ..ut_filter import non_graph_engine
  30. from ....mindspore_test_framework.utils.bprop_util import bprop
  31. def setup_module(module):
  32. context.set_context(mode=context.PYNATIVE_MODE)
  33. def stop_func(x, y):
  34. """ stop_func"""
  35. c = x * y
  36. c_s = x + y
  37. return c_s, c
  38. def stop_test1(x, y):
  39. """ stop_test1 """
  40. c = x * y
  41. c_s = stop_gradient(c)
  42. return c_s
  43. def stop_test2(x, y):
  44. """ stop_test2 """
  45. c = x * y
  46. c_s = stop_gradient(c)
  47. d = c_s + x * y
  48. return d * y
  49. def stop_test3(x, y):
  50. """ stop_test3 """
  51. x = x * y
  52. z = stop_test1(x, y)
  53. k = z * y
  54. return k
  55. def stop_test5(x, y):
  56. """ stop_test3 """
  57. x = x + y
  58. o1, o2 = stop_func(x, y)
  59. c = stop_gradient(o1)
  60. c = o2 + c
  61. return c
  62. def stop_test4(x, y):
  63. """ stop_test4 """
  64. c = x + y
  65. c_s = stop_gradient(c)
  66. e = c + c_s
  67. return e
  68. @ms_function
  69. def grad_stop_test(x, y):
  70. """ grad_stop_test """
  71. return C.grad_all(stop_test2)(x, y)
  72. @ms_function
  73. def grad_stop_test1(x, y):
  74. """ grad_stop_test1 """
  75. return C.grad_all(stop_test3)(x, y)
  76. @ms_function
  77. def grad_stop_test5(x, y):
  78. """ grad_stop_test5 """
  79. return C.grad_all(stop_test5)(x, y)
  80. def test_stop():
  81. """ test_stop """
  82. print("test_stop:", grad_stop_test(1, 1))
  83. def test_stop1():
  84. """ test_stop1 """
  85. print("test_stop1:", grad_stop_test1(2, 3))
  86. def test_stop5():
  87. """ test_stop1 """
  88. print("test_stop5:", grad_stop_test5(2, 3))
  89. class GradWrap(nn.Cell):
  90. """ GradWrap definition """
  91. def __init__(self, network):
  92. super(GradWrap, self).__init__()
  93. self.network = network
  94. self.weights = ParameterTuple(network.get_parameters())
  95. @ms_function
  96. def construct(self, x, label):
  97. weights = self.weights
  98. return C.grad_by_list(self.network, weights)(x, label)
  99. @non_graph_engine
  100. def test_softmaxloss_grad():
  101. """ test_softmaxloss_grad """
  102. class NetWithLossClass(nn.Cell):
  103. """ NetWithLossClass definition """
  104. def __init__(self, network):
  105. super(NetWithLossClass, self).__init__()
  106. self.loss = nn.SoftmaxCrossEntropyWithLogits()
  107. self.network = network
  108. @ms_function
  109. def construct(self, x, label):
  110. predict = self.network(x)
  111. return self.loss(predict, label)
  112. class Net(nn.Cell):
  113. """ Net definition """
  114. def __init__(self):
  115. super(Net, self).__init__()
  116. self.weight = Parameter(Tensor(np.ones([64, 10])), name="weight")
  117. self.bias = Parameter(Tensor(np.ones([10]).astype(np.float32)), name="bias")
  118. self.fc = P.MatMul()
  119. self.fc2 = nn.Dense(10, 10)
  120. self.biasAdd = P.BiasAdd()
  121. self.relu = nn.ReLU()
  122. self.cast = P.Cast()
  123. @ms_function
  124. def construct(self, x):
  125. x = self.fc(x, self.weight)
  126. x = self.cast(x, mstype.float32)
  127. x = self.relu(self.fc2(x))
  128. x = self.fc2(x)
  129. x = stop_gradient(x)
  130. x = self.biasAdd(x, self.bias)
  131. return x
  132. net = GradWrap(NetWithLossClass(Net()))
  133. predict = Tensor(np.ones([1, 64]))
  134. label = Tensor(np.zeros([1, 10]).astype(np.float32))
  135. print("pynative run")
  136. out = net(predict, label)
  137. print("out:", out)
  138. def test_stop_gradient_1():
  139. class Mul(nn.Cell):
  140. def __init__(self):
  141. super(Mul, self).__init__()
  142. @ms_function
  143. def construct(self, x, y):
  144. ret = x * y
  145. ret = stop_gradient(ret)
  146. return ret
  147. dx, dy = bprop(Mul(), Tensor(np.ones([2, 2]).astype(np.float32)),
  148. Tensor(np.ones([2, 2]).astype(np.float32)), wrt=['inputs'])
  149. expect = np.zeros([2, 2])
  150. assert (dx.asnumpy() == expect).all()
  151. assert (dy.asnumpy() == expect).all()
  152. def test_stop_gradient_2():
  153. class Mul(nn.Cell):
  154. def __init__(self):
  155. super(Mul, self).__init__()
  156. @ms_function
  157. def construct(self, x, y):
  158. c = x * y
  159. z = x * y
  160. return c, z
  161. class MulAdd(nn.Cell):
  162. def __init__(self):
  163. super(MulAdd, self).__init__()
  164. self.mul = Mul()
  165. @ms_function
  166. def construct(self, x, y):
  167. u = x + y
  168. v = x - y
  169. c, z = self.mul(u, v)
  170. c = stop_gradient(c)
  171. ret1 = c + x + y
  172. ret2 = z + y + y
  173. return ret1, ret2
  174. dx = bprop(MulAdd(), Tensor(np.ones([2, 2]).astype(np.float32)),
  175. Tensor(np.ones([2, 2]).astype(np.float32)))
  176. expect = np.array([[3.0, 3.0], [3.0, 3.0]])
  177. assert (dx.asnumpy() == expect).all()
  178. def test_stop_gradient_3():
  179. class TupleGetItem(nn.Cell):
  180. def __init__(self):
  181. super(TupleGetItem, self).__init__()
  182. @ms_function
  183. def construct(self, x1, x2, x3, x4, x5):
  184. z1 = x1 + x1
  185. z2 = x1 * x2
  186. t = (z1, z2, x3, x4, x5)
  187. z2 = t[1]
  188. z2 = stop_gradient(z2)
  189. return z1, z2, x3, x4, x5
  190. dx = bprop(TupleGetItem(),
  191. Tensor(np.ones([2]).astype(np.float32)),
  192. Tensor(np.ones([2]).astype(np.float32)),
  193. Tensor(np.ones([2]).astype(np.float32)),
  194. Tensor(np.ones([2]).astype(np.float32)),
  195. Tensor(np.ones([2]).astype(np.float32)))
  196. expect = np.array([[2.0, 2.0], [2.0, 2.0]])
  197. assert (dx.asnumpy() == expect).all()
  198. def test_stop_gradient_4():
  199. def stop_test(x):
  200. return stop_gradient(x)
  201. assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (1,)
  202. def test_stop_gradient_5():
  203. def stop_test(x):
  204. y = x + x
  205. y = stop_gradient(y)
  206. ret = x + y
  207. return ret
  208. assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (1,)
  209. def test_stop_gradient_6():
  210. def stop_test(x, y):
  211. ret = x * y
  212. ret = stop_gradient(ret)
  213. return ret
  214. assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (0, 0)
  215. class PrimWithMultiOutputs(PrimitiveWithInfer):
  216. @prim_attr_register
  217. def __init__(self):
  218. """init"""
  219. def __call__(self, x, y):
  220. """Implement by vm mode."""
  221. return x, y
  222. def infer_shape(self, x_shape, y_shape):
  223. return x_shape, y_shape
  224. def infer_dtype(self, x_type, y_type):
  225. return x_type, y_type
  226. def get_bprop(self):
  227. def bprop(x, y, out, dout):
  228. return (dout[0], dout[1])
  229. return bprop
  230. def test_stop_gradient_7():
  231. class PrimWithMultiOutputs_(nn.Cell):
  232. def __init__(self):
  233. super(PrimWithMultiOutputs_, self).__init__()
  234. self.prim_with_multi_outputs = PrimWithMultiOutputs()
  235. @ms_function
  236. def construct(self, x1, x2):
  237. x1, x2 = self.prim_with_multi_outputs(x1, x2)
  238. x1 = stop_gradient(x1)
  239. return x1, x2
  240. dx, dy = bprop(PrimWithMultiOutputs_(), Tensor(np.ones([2]).astype(np.float32)),
  241. Tensor(np.ones([2]).astype(np.float32)), wrt=['inputs'])
  242. expect_dx = np.zeros([2])
  243. expect_dy = np.ones([2])
  244. assert (dx.asnumpy() == expect_dx).all()
  245. assert (dy.asnumpy() == expect_dy).all()
  246. def test_stop_gradient_8():
  247. class PrimWithMultiOutputs_(nn.Cell):
  248. def __init__(self):
  249. super(PrimWithMultiOutputs_, self).__init__()
  250. self.prim_with_multi_output = PrimWithMultiOutputs()
  251. @ms_function
  252. def construct(self, x1, x2):
  253. x1, x2 = stop_gradient(self.prim_with_multi_output(x1, x2))
  254. return x1, x2
  255. dx, dy = bprop(PrimWithMultiOutputs_(), Tensor(np.ones([2]).astype(np.float32)),
  256. Tensor(np.ones([2]).astype(np.float32)), wrt=['inputs'])
  257. expect_dx = np.zeros([2])
  258. expect_dy = np.zeros([2])
  259. assert (dx.asnumpy() == expect_dx).all()
  260. assert (dy.asnumpy() == expect_dy).all()
  261. def test_stop_gradient_9():
  262. class Mul(nn.Cell):
  263. def __init__(self):
  264. super(Mul, self).__init__()
  265. @ms_function
  266. def construct(self, x, y):
  267. c = x * y
  268. z = x * y
  269. return c, z
  270. class MulAdd(nn.Cell):
  271. def __init__(self):
  272. super(MulAdd, self).__init__()
  273. self.mul = Mul()
  274. @ms_function
  275. def construct(self, x, y):
  276. u = x + y
  277. v = x - y
  278. c, z = self.mul(u, v)
  279. c1 = stop_gradient(c)
  280. c2 = c
  281. ret1 = c1 + x + y + c2
  282. ret2 = z + y + y
  283. return ret1, ret2
  284. dx = bprop(MulAdd(), Tensor(np.ones([2, 2]).astype(np.float32)),
  285. Tensor(np.ones([2, 2]).astype(np.float32)))
  286. expect = np.array([[5.0, 5.0], [5.0, 5.0]])
  287. assert (dx.asnumpy() == expect).all()
  288. class PrimWithNoBprop(PrimitiveWithInfer):
  289. @prim_attr_register
  290. def __init__(self):
  291. """init"""
  292. def __call__(self, x, y):
  293. """Implement by vm mode."""
  294. return x, y
  295. def infer_shape(self, x_shape, y_shape):
  296. return x_shape, y_shape
  297. def infer_dtype(self, x_type, y_type):
  298. return x_type, y_type
  299. def test_stop_gradient_10():
  300. class PrimWithNoBprop_(nn.Cell):
  301. def __init__(self):
  302. super(PrimWithNoBprop_, self).__init__()
  303. self.prim_with_no_bprop = PrimWithNoBprop()
  304. @ms_function
  305. def construct(self, x, y):
  306. x = x * y
  307. x, y = self.prim_with_no_bprop(x, y)
  308. x = stop_gradient(x)
  309. y = stop_gradient(y)
  310. return x, y
  311. dx = bprop(PrimWithNoBprop_(), Tensor(np.ones([2]).astype(np.float32)),
  312. Tensor(np.ones([2]).astype(np.float32)))
  313. expect_dx = np.zeros([2])
  314. assert (dx.asnumpy() == expect_dx).all()
  315. def test_stop_gradient_11():
  316. class PrimWithNoBprop_(nn.Cell):
  317. def __init__(self):
  318. super(PrimWithNoBprop_, self).__init__()
  319. self.prim_with_no_bprop = PrimWithNoBprop()
  320. @ms_function
  321. def construct(self, x, y):
  322. x, y = self.prim_with_no_bprop(x, y)
  323. x = stop_gradient(x)
  324. return x, y
  325. with pytest.raises(RuntimeError):
  326. bprop(PrimWithNoBprop_(), Tensor(np.ones([2]).astype(np.float32)),
  327. Tensor(np.ones([2]).astype(np.float32)))
  328. def test_stop_print():
  329. class StopPrint(nn.Cell):
  330. def __init__(self):
  331. super(StopPrint, self).__init__()
  332. self.printm = P.Print()
  333. def construct(self, x, y):
  334. self.printm("StopPrint", x)
  335. self.printm(y)
  336. return x, y
  337. C.grad_all(StopPrint())(Tensor(np.ones([2]).astype(np.float32)),
  338. Tensor(np.ones([2]).astype(np.float32)))