You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_auto_monad.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. import numpy as np
  2. import pytest
  3. import mindspore as ms
  4. import mindspore.ops.composite as C
  5. from mindspore import context
  6. import mindspore.nn as nn
  7. from mindspore.ops import operations as P
  8. from mindspore.ops import functional as F
  9. from mindspore import Tensor
  10. from mindspore.common.parameter import Parameter, ParameterTuple
# Gradient helpers shared by every test in this file.
grad_all_list = C.GradOperation(get_all=True, get_by_list=True)  # grads w.r.t. inputs AND listed params
grad_by_list = C.GradOperation(get_by_list=True)  # grads w.r.t. listed params only
# All tests below compile in graph mode; saved graph dumps are disabled.
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
  14. def test_load_grad():
  15. class LoadNet(nn.Cell):
  16. def __init__(self):
  17. super().__init__()
  18. self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
  19. def construct(self, x, y):
  20. x = x * y * self.z
  21. return x
  22. x = Tensor(np.array([2.0], np.float32))
  23. y = Tensor(np.array([3.0], np.float32))
  24. load_net = LoadNet()
  25. grad_net = grad_all_list(
  26. load_net, ParameterTuple(load_net.trainable_params()))
  27. print(grad_net(x, y))
  28. def test_assign_only_grad():
  29. class AssignOnlyNet(nn.Cell):
  30. def __init__(self):
  31. super().__init__()
  32. self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
  33. def construct(self, x, y):
  34. self.z = x
  35. x = x * y
  36. return x
  37. class GradNet(nn.Cell):
  38. def __init__(self, net):
  39. super(GradNet, self).__init__()
  40. self.net = net
  41. self.parameter_tuple = ParameterTuple(self.trainable_params())
  42. def construct(self, x, y):
  43. return grad_all_list(self.net, self.parameter_tuple)(x, y)
  44. assign_net = AssignOnlyNet()
  45. net = GradNet(assign_net)
  46. x = Tensor(np.array([2.0], np.float32))
  47. y = Tensor(np.array([3.0], np.float32))
  48. print(net(x, y))
  49. def test_load_assign_grad():
  50. class AssignNet(nn.Cell):
  51. def __init__(self):
  52. super().__init__()
  53. self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
  54. self.assign = P.Assign()
  55. def construct(self, x, y):
  56. x = x * self.z
  57. self.assign(self.z, x)
  58. out = y * self.z
  59. return out
  60. class GradNet(nn.Cell):
  61. def __init__(self, net):
  62. super(GradNet, self).__init__()
  63. self.net = net
  64. self.parameter_tuple = ParameterTuple(net.trainable_params())
  65. def construct(self, x, y):
  66. return grad_all_list(self.net, self.parameter_tuple)(x, y)
  67. assign_net = AssignNet()
  68. net = GradNet(assign_net)
  69. x = Tensor(np.array([2.0], np.float32))
  70. y = Tensor(np.array([3.0], np.float32))
  71. print(net(x, y))
  72. def test_insert_gradient_of():
  73. class InsertGradientNet(nn.Cell):
  74. def __init__(self):
  75. super(InsertGradientNet, self).__init__()
  76. self.gather = P.GatherV2()
  77. self.damping = Tensor(np.array([0.03, 0.03], np.float32))
  78. self.cov_step = Parameter(0, name="cov_step", requires_grad=False)
  79. self.freq = Tensor(278, ms.int32)
  80. self.getG = P.InsertGradientOf(self.save_gradient)
  81. def save_gradient(self, dout):
  82. self.cov_step = self.cov_step + self.freq
  83. return dout
  84. def construct(self, x):
  85. self.gather(self.damping, self.cov_step, 0)
  86. out = P.ReLU()(x)
  87. out = self.getG(out)
  88. out = self.getG(out)
  89. return out
  90. net = InsertGradientNet()
  91. input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype(np.float32)
  92. grad_net = grad_all_list(net, ParameterTuple(net.trainable_params()))
  93. print(grad_net(Tensor(input_data)))
  94. def test_user_defined_bprop():
  95. class UserDefinedNet(nn.Cell):
  96. def __init__(self):
  97. super().__init__()
  98. self.print = P.Print()
  99. def construct(self, x, y):
  100. out = x * y
  101. return out
  102. def bprop(self, x, y, out, dout):
  103. self.print(out)
  104. out = x * y
  105. self.print(out)
  106. self.print(dout)
  107. return y, x
  108. class GradNet(nn.Cell):
  109. def __init__(self, net):
  110. super(GradNet, self).__init__()
  111. self.net = net
  112. self.parameter_tuple = ParameterTuple(net.trainable_params())
  113. def construct(self, x, y):
  114. return grad_all_list(self.net, self.parameter_tuple)(x, y)
  115. user_defined_net = UserDefinedNet()
  116. net = GradNet(user_defined_net)
  117. x = Tensor(np.array([2.0], np.float32))
  118. y = Tensor(np.array([3.0], np.float32))
  119. print(net(x, y))
  120. # user defined bprop don't have the same size of parameters with primal's
  121. def test_user_defined_bad_bprop():
  122. class UserDefinedNet(nn.Cell):
  123. def __init__(self):
  124. super().__init__()
  125. self.print = P.Print()
  126. def construct(self, x, y):
  127. out = x * y
  128. return out
  129. def bprop(self, x, out, dout):
  130. self.print(out)
  131. out = x
  132. self.print(out)
  133. self.print(dout)
  134. return x, x
  135. class GradNet(nn.Cell):
  136. def __init__(self, net):
  137. super(GradNet, self).__init__()
  138. self.net = net
  139. self.parameter_tuple = ParameterTuple(net.trainable_params())
  140. def construct(self, x, y):
  141. return grad_all_list(self.net, self.parameter_tuple)(x, y)
  142. user_defined_net = UserDefinedNet()
  143. net = GradNet(user_defined_net)
  144. x = Tensor(np.array([2.0], np.float32))
  145. y = Tensor(np.array([3.0], np.float32))
  146. with pytest.raises(TypeError):
  147. net(x, y)
  148. # shoul compile success and Print in presented in the final function graph.
  149. @pytest.mark.skip(reason="isolated nodes exception")
  150. def test_unused_var():
  151. class UnusedVar(nn.Cell):
  152. def __init__(self):
  153. super().__init__()
  154. self.print = P.Print()
  155. def construct(self, x, y):
  156. shape1 = self.get_shape(x)
  157. out = x
  158. for _ in range(shape1):
  159. out = out + y
  160. return out
  161. def get_shape(self, x):
  162. self.print(x)
  163. _, c, _, _ = F.shape(x)
  164. return c
  165. net = UnusedVar()
  166. x = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
  167. y = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
  168. print(net(x, y))
  169. # shoul compile success and Print in presented in the final function graph.
  170. @pytest.mark.skip(reason="isolated nodes exception")
  171. def test_hof_unused_var():
  172. class UnusedVar(nn.Cell):
  173. def __init__(self):
  174. super().__init__()
  175. self.print = P.Print()
  176. def construct(self, x, y):
  177. shape1 = self.hof_get_shape(self.get_shape, x)
  178. out = x
  179. for _ in range(shape1):
  180. out = out + y
  181. return out
  182. def hof_get_shape(self, hof, x):
  183. return hof(x)
  184. def get_shape(self, x):
  185. self.print(x)
  186. _, c, _, _ = F.shape(x)
  187. return c
  188. net = UnusedVar()
  189. x = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
  190. y = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
  191. print(net(x, y))
  192. # shoul compile success and Print in presented in the final function graph.
  193. @pytest.mark.skip(reason="isolated nodes exception")
  194. def test_partial_hof_unused_var():
  195. class UnusedVar(nn.Cell):
  196. def __init__(self):
  197. super().__init__()
  198. self.print = P.Print()
  199. def construct(self, x, y):
  200. shape1 = self.hof_get_shape(x)()
  201. out = x
  202. for _ in range(shape1):
  203. out = out + y
  204. return out
  205. def hof_get_shape(self, x):
  206. return F.partial(self.get_shape, x)
  207. def get_shape(self, x):
  208. self.print(x)
  209. _, c, _, _ = F.shape(x)
  210. return c
  211. net = UnusedVar()
  212. x = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
  213. y = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
  214. print(net(x, y))
  215. # should compile success without endless loop.
  216. def test_while_if():
  217. class WhileIfNet(nn.Cell):
  218. def __init__(self):
  219. super().__init__()
  220. self.zero = Tensor(np.zeros([1]).astype(np.float32))
  221. self.param = Parameter(Tensor(np.zeros([1]).astype(np.float32)))
  222. def construct(self, idx, end, x):
  223. out = self.zero
  224. while idx < end:
  225. if x < end:
  226. out = out + self.param * 2
  227. else:
  228. out = out + self.param
  229. idx = idx + 1
  230. return out
  231. idx = Tensor(np.array(0), dtype=ms.int32)
  232. end = Tensor(np.array(5), dtype=ms.int32)
  233. x = Tensor(np.zeros([1]).astype(np.float32))
  234. m = WhileIfNet()
  235. m(idx, end, x)
  236. # should compile success without zeros_like_tensor args mismatch, the generated graph files
  237. # should not contain env_getitem or env_setitem.
  238. # InsertGradientOf primitive will make func_graph bprop_construct had BackPropAutoMonad flag set,
  239. # so all graph it used will be checked if any side effect it has, so the hyper_map_zeros_like
  240. # will have U as parameter, but the call site zeros_like(fv) don't have U argument.
  241. def test_grad_fv_and_insert_gradient_of():
  242. class FvAndInsertGradientNet(nn.Cell):
  243. def __init__(self):
  244. super(FvAndInsertGradientNet, self).__init__()
  245. self.gather = P.GatherV2()
  246. self.damping = Tensor(np.array([0.03, 0.03], np.float32))
  247. self.cov_step = Parameter(0, name="cov_step", requires_grad=False)
  248. self.freq = Tensor(278, ms.int32)
  249. self.getG = P.InsertGradientOf(self.save_gradient)
  250. self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
  251. def save_gradient(self, dout):
  252. self.cov_step = self.cov_step + self.freq
  253. return dout
  254. def construct(self, *inputs):
  255. # fv self.z from construct_wrapper
  256. x, = inputs
  257. self.z = x
  258. # insert_gradient_of
  259. self.gather(self.damping, self.cov_step, 0)
  260. out = self.getG(x)
  261. return out
  262. net = FvAndInsertGradientNet()
  263. input_data = Tensor(np.array([1.0], np.float32))
  264. # if use grad_all_list, the generated graph will have env_setitem
  265. # as gradient for inputs is constant zero, so it will depend on result of grad.
  266. grad_net = grad_by_list(net, ParameterTuple(net.trainable_params()))
  267. print(grad_net(input_data))
  268. # should compile success as cnode with Partial primitive will not bind an additional U monad.
  269. def test_partial_parameter():
  270. z = Parameter(Tensor(np.array([True], np.bool_)), name='z')
  271. class PartialNet(nn.Cell):
  272. def __init__(self, input_z):
  273. super().__init__()
  274. self.input = input_z
  275. def construct(self):
  276. # getattr of all will be convert to Partial
  277. out = self.input.all(axis=())
  278. return out
  279. net = PartialNet(z)
  280. print(net())