You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_auto_monad.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore as ms
  18. import mindspore.ops.composite as C
  19. from mindspore import context
  20. import mindspore.nn as nn
  21. from mindspore.ops import operations as P
  22. from mindspore.ops import functional as F
  23. from mindspore import Tensor
  24. from mindspore.common.parameter import Parameter, ParameterTuple
  25. from tests.security_utils import security_off_wrap
# Grad operator returning gradients w.r.t. all inputs AND the given parameter list.
grad_all_list = C.GradOperation(get_all=True, get_by_list=True)
# Grad operator returning gradients w.r.t. the given parameter list only.
grad_by_list = C.GradOperation(get_by_list=True)
# Every test in this file exercises auto-monad behavior under graph (compiled) mode.
context.set_context(mode=context.GRAPH_MODE)
  29. def test_load_grad():
  30. class LoadNet(nn.Cell):
  31. def __init__(self):
  32. super().__init__()
  33. self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
  34. def construct(self, x, y):
  35. x = x * y * self.z
  36. return x
  37. x = Tensor(np.array([2.0], np.float32))
  38. y = Tensor(np.array([3.0], np.float32))
  39. load_net = LoadNet()
  40. grad_net = grad_all_list(
  41. load_net, ParameterTuple(load_net.trainable_params()))
  42. print(grad_net(x, y))
  43. def test_assign_only_grad():
  44. class AssignOnlyNet(nn.Cell):
  45. def __init__(self):
  46. super().__init__()
  47. self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
  48. def construct(self, x, y):
  49. self.z = x
  50. x = x * y
  51. return x
  52. class GradNet(nn.Cell):
  53. def __init__(self, net):
  54. super(GradNet, self).__init__()
  55. self.net = net
  56. self.parameter_tuple = ParameterTuple(self.trainable_params())
  57. def construct(self, x, y):
  58. return grad_all_list(self.net, self.parameter_tuple)(x, y)
  59. assign_net = AssignOnlyNet()
  60. net = GradNet(assign_net)
  61. x = Tensor(np.array([2.0], np.float32))
  62. y = Tensor(np.array([3.0], np.float32))
  63. print(net(x, y))
  64. def test_load_assign_grad():
  65. class AssignNet(nn.Cell):
  66. def __init__(self):
  67. super().__init__()
  68. self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
  69. self.assign = P.Assign()
  70. def construct(self, x, y):
  71. x = x * self.z
  72. self.assign(self.z, x)
  73. out = y * self.z
  74. return out
  75. class GradNet(nn.Cell):
  76. def __init__(self, net):
  77. super(GradNet, self).__init__()
  78. self.net = net
  79. self.parameter_tuple = ParameterTuple(net.trainable_params())
  80. def construct(self, x, y):
  81. return grad_all_list(self.net, self.parameter_tuple)(x, y)
  82. assign_net = AssignNet()
  83. net = GradNet(assign_net)
  84. x = Tensor(np.array([2.0], np.float32))
  85. y = Tensor(np.array([3.0], np.float32))
  86. print(net(x, y))
  87. def test_insert_gradient_of():
  88. class InsertGradientNet(nn.Cell):
  89. def __init__(self):
  90. super(InsertGradientNet, self).__init__()
  91. self.gather = P.GatherV2()
  92. self.damping = Tensor(np.array([0.03, 0.03], np.float32))
  93. self.cov_step = Parameter(0, name="cov_step", requires_grad=False)
  94. self.freq = Tensor(278, ms.int32)
  95. self.getG = P.InsertGradientOf(self.save_gradient)
  96. def save_gradient(self, dout):
  97. self.cov_step = self.cov_step + self.freq
  98. return dout
  99. def construct(self, x):
  100. self.gather(self.damping, self.cov_step, 0)
  101. out = P.ReLU()(x)
  102. out = self.getG(out)
  103. out = self.getG(out)
  104. return out
  105. net = InsertGradientNet()
  106. input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype(np.float32)
  107. grad_net = grad_all_list(net, ParameterTuple(net.trainable_params()))
  108. print(grad_net(Tensor(input_data)))
  109. @security_off_wrap
  110. def test_user_defined_bprop():
  111. class UserDefinedNet(nn.Cell):
  112. def __init__(self):
  113. super().__init__()
  114. self.print = P.Print()
  115. def construct(self, x, y):
  116. out = x * y
  117. return out
  118. def bprop(self, x, y, out, dout):
  119. self.print(out)
  120. out = x * y
  121. self.print(out)
  122. self.print(dout)
  123. return y, x
  124. class GradNet(nn.Cell):
  125. def __init__(self, net):
  126. super(GradNet, self).__init__()
  127. self.net = net
  128. self.parameter_tuple = ParameterTuple(net.trainable_params())
  129. def construct(self, x, y):
  130. return grad_all_list(self.net, self.parameter_tuple)(x, y)
  131. user_defined_net = UserDefinedNet()
  132. net = GradNet(user_defined_net)
  133. x = Tensor(np.array([2.0], np.float32))
  134. y = Tensor(np.array([3.0], np.float32))
  135. print(net(x, y))
  136. # user defined bprop don't have the same size of parameters with primal's
  137. @security_off_wrap
  138. def test_user_defined_bad_bprop():
  139. class UserDefinedNet(nn.Cell):
  140. def __init__(self):
  141. super().__init__()
  142. self.print = P.Print()
  143. def construct(self, x, y):
  144. out = x * y
  145. return out
  146. def bprop(self, x, out, dout):
  147. self.print(out)
  148. out = x
  149. self.print(out)
  150. self.print(dout)
  151. return x, x
  152. class GradNet(nn.Cell):
  153. def __init__(self, net):
  154. super(GradNet, self).__init__()
  155. self.net = net
  156. self.parameter_tuple = ParameterTuple(net.trainable_params())
  157. def construct(self, x, y):
  158. return grad_all_list(self.net, self.parameter_tuple)(x, y)
  159. user_defined_net = UserDefinedNet()
  160. net = GradNet(user_defined_net)
  161. x = Tensor(np.array([2.0], np.float32))
  162. y = Tensor(np.array([3.0], np.float32))
  163. with pytest.raises(TypeError):
  164. net(x, y)
  165. # shoul compile success and Print in presented in the final function graph.
  166. @security_off_wrap
  167. @pytest.mark.skip(reason="isolated nodes exception")
  168. def test_unused_var():
  169. class UnusedVar(nn.Cell):
  170. def __init__(self):
  171. super().__init__()
  172. self.print = P.Print()
  173. def construct(self, x, y):
  174. shape1 = self.get_shape(x)
  175. out = x
  176. for _ in range(shape1):
  177. out = out + y
  178. return out
  179. def get_shape(self, x):
  180. self.print(x)
  181. _, c, _, _ = F.shape(x)
  182. return c
  183. net = UnusedVar()
  184. x = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
  185. y = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
  186. print(net(x, y))
  187. # shoul compile success and Print in presented in the final function graph.
  188. @security_off_wrap
  189. @pytest.mark.skip(reason="isolated nodes exception")
  190. def test_hof_unused_var():
  191. class UnusedVar(nn.Cell):
  192. def __init__(self):
  193. super().__init__()
  194. self.print = P.Print()
  195. def construct(self, x, y):
  196. shape1 = self.hof_get_shape(self.get_shape, x)
  197. out = x
  198. for _ in range(shape1):
  199. out = out + y
  200. return out
  201. def hof_get_shape(self, hof, x):
  202. return hof(x)
  203. def get_shape(self, x):
  204. self.print(x)
  205. _, c, _, _ = F.shape(x)
  206. return c
  207. net = UnusedVar()
  208. x = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
  209. y = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
  210. print(net(x, y))
  211. # shoul compile success and Print in presented in the final function graph.
  212. @security_off_wrap
  213. @pytest.mark.skip(reason="isolated nodes exception")
  214. def test_partial_hof_unused_var():
  215. class UnusedVar(nn.Cell):
  216. def __init__(self):
  217. super().__init__()
  218. self.print = P.Print()
  219. def construct(self, x, y):
  220. shape1 = self.hof_get_shape(x)()
  221. out = x
  222. for _ in range(shape1):
  223. out = out + y
  224. return out
  225. def hof_get_shape(self, x):
  226. return F.partial(self.get_shape, x)
  227. def get_shape(self, x):
  228. self.print(x)
  229. _, c, _, _ = F.shape(x)
  230. return c
  231. net = UnusedVar()
  232. x = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
  233. y = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
  234. print(net(x, y))
  235. # should compile success without endless loop.
  236. def test_while_if():
  237. class WhileIfNet(nn.Cell):
  238. def __init__(self):
  239. super().__init__()
  240. self.zero = Tensor(np.zeros([1]).astype(np.float32))
  241. self.param = Parameter(Tensor(np.zeros([1]).astype(np.float32)))
  242. def construct(self, idx, end, x):
  243. out = self.zero
  244. while idx < end:
  245. if x < end:
  246. out = out + self.param * 2
  247. else:
  248. out = out + self.param
  249. idx = idx + 1
  250. return out
  251. idx = Tensor(np.array(0), dtype=ms.int32)
  252. end = Tensor(np.array(5), dtype=ms.int32)
  253. x = Tensor(np.zeros([1]).astype(np.float32))
  254. m = WhileIfNet()
  255. m(idx, end, x)
  256. # should compile success without zeros_like_tensor args mismatch, the generated graph files
  257. # should not contain env_getitem or env_setitem.
  258. # InsertGradientOf primitive will make func_graph bprop_construct had BackPropAutoMonad flag set,
  259. # so all graph it used will be checked if any side effect it has, so the hyper_map_zeros_like
  260. # will have U as parameter, but the call site zeros_like(fv) don't have U argument.
  261. def test_grad_fv_and_insert_gradient_of():
  262. class FvAndInsertGradientNet(nn.Cell):
  263. def __init__(self):
  264. super(FvAndInsertGradientNet, self).__init__()
  265. self.gather = P.GatherV2()
  266. self.damping = Tensor(np.array([0.03, 0.03], np.float32))
  267. self.cov_step = Parameter(0, name="cov_step", requires_grad=False)
  268. self.freq = Tensor(278, ms.int32)
  269. self.getG = P.InsertGradientOf(self.save_gradient)
  270. self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
  271. def save_gradient(self, dout):
  272. self.cov_step = self.cov_step + self.freq
  273. return dout
  274. def construct(self, *inputs):
  275. # fv self.z from construct_wrapper
  276. x, = inputs
  277. self.z = x
  278. # insert_gradient_of
  279. self.gather(self.damping, self.cov_step, 0)
  280. out = self.getG(x)
  281. return out
  282. net = FvAndInsertGradientNet()
  283. input_data = Tensor(np.array([1.0], np.float32))
  284. # if use grad_all_list, the generated graph will have env_setitem
  285. # as gradient for inputs is constant zero, so it will depend on result of grad.
  286. grad_net = grad_by_list(net, ParameterTuple(net.trainable_params()))
  287. print(grad_net(input_data))
  288. # should compile success as cnode with Partial primitive will not bind an additional U monad.
  289. def test_partial_parameter():
  290. z = Parameter(Tensor(np.array([True], np.bool_)), name='z')
  291. class PartialNet(nn.Cell):
  292. def __init__(self, input_z):
  293. super().__init__()
  294. self.input = input_z
  295. def construct(self):
  296. # getattr of all will be convert to Partial
  297. out = self.input.all(axis=())
  298. return out
  299. net = PartialNet(z)
  300. print(net())