You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_auto_monad.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. import numpy as np
  2. import pytest
  3. import mindspore as ms
  4. import mindspore.ops.composite as C
  5. from mindspore import context
  6. import mindspore.nn as nn
  7. from mindspore.ops import operations as P
  8. from mindspore.ops import functional as F
  9. from mindspore import Tensor
  10. from mindspore.common.parameter import Parameter, ParameterTuple
  11. grad_all_list = C.GradOperation(get_all=True, get_by_list=True)
  12. grad_by_list = C.GradOperation(get_by_list=True)
  13. context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
  14. def test_load_grad():
  15. class LoadNet(nn.Cell):
  16. def __init__(self):
  17. super().__init__()
  18. self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
  19. def construct(self, x, y):
  20. x = x * y * self.z
  21. return x
  22. x = Tensor(np.array([2.0], np.float32))
  23. y = Tensor(np.array([3.0], np.float32))
  24. load_net = LoadNet()
  25. grad_net = grad_all_list(
  26. load_net, ParameterTuple(load_net.trainable_params()))
  27. print(grad_net(x, y))
  28. def test_assign_only_grad():
  29. class AssignOnlyNet(nn.Cell):
  30. def __init__(self):
  31. super().__init__()
  32. self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
  33. def construct(self, x, y):
  34. self.z = x
  35. x = x * y
  36. return x
  37. class GradNet(nn.Cell):
  38. def __init__(self, net):
  39. super(GradNet, self).__init__()
  40. self.net = net
  41. self.parameter_tuple = ParameterTuple(self.trainable_params())
  42. def construct(self, x, y):
  43. return grad_all_list(self.net, self.parameter_tuple)(x, y)
  44. assign_net = AssignOnlyNet()
  45. net = GradNet(assign_net)
  46. x = Tensor(np.array([2.0], np.float32))
  47. y = Tensor(np.array([3.0], np.float32))
  48. print(net(x, y))
  49. def test_load_assign_grad():
  50. class AssignNet(nn.Cell):
  51. def __init__(self):
  52. super().__init__()
  53. self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
  54. self.assign = P.Assign()
  55. def construct(self, x, y):
  56. x = x * self.z
  57. self.assign(self.z, x)
  58. out = y * self.z
  59. return out
  60. class GradNet(nn.Cell):
  61. def __init__(self, net):
  62. super(GradNet, self).__init__()
  63. self.net = net
  64. self.parameter_tuple = ParameterTuple(net.trainable_params())
  65. def construct(self, x, y):
  66. return grad_all_list(self.net, self.parameter_tuple)(x, y)
  67. assign_net = AssignNet()
  68. net = GradNet(assign_net)
  69. x = Tensor(np.array([2.0], np.float32))
  70. y = Tensor(np.array([3.0], np.float32))
  71. print(net(x, y))
  72. def test_insert_gradient_of():
  73. class InsertGradientNet(nn.Cell):
  74. def __init__(self):
  75. super(InsertGradientNet, self).__init__()
  76. self.gather = P.GatherV2()
  77. self.damping = Tensor(np.array([0.03, 0.03], np.float32))
  78. self.cov_step = Parameter(0, name="cov_step", requires_grad=False)
  79. self.freq = Tensor(278, ms.int32)
  80. self.getG = P.InsertGradientOf(self.save_gradient)
  81. def save_gradient(self, dout):
  82. self.cov_step = self.cov_step + self.freq
  83. return dout
  84. def construct(self, x):
  85. self.gather(self.damping, self.cov_step, 0)
  86. out = P.ReLU()(x)
  87. out = self.getG(out)
  88. out = self.getG(out)
  89. return out
  90. net = InsertGradientNet()
  91. input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype(np.float32)
  92. grad_net = grad_all_list(net, ParameterTuple(net.trainable_params()))
  93. print(grad_net(Tensor(input_data)))
  94. def test_user_defined_bprop():
  95. class UserDefinedNet(nn.Cell):
  96. def __init__(self):
  97. super().__init__()
  98. self.print = P.Print()
  99. def construct(self, x, y):
  100. out = x * y
  101. return out
  102. def bprop(self, x, y, out, dout):
  103. self.print(out)
  104. out = x * y
  105. self.print(out)
  106. self.print(dout)
  107. return y, x
  108. class GradNet(nn.Cell):
  109. def __init__(self, net):
  110. super(GradNet, self).__init__()
  111. self.net = net
  112. self.parameter_tuple = ParameterTuple(net.trainable_params())
  113. def construct(self, x, y):
  114. return grad_all_list(self.net, self.parameter_tuple)(x, y)
  115. user_defined_net = UserDefinedNet()
  116. net = GradNet(user_defined_net)
  117. x = Tensor(np.array([2.0], np.float32))
  118. y = Tensor(np.array([3.0], np.float32))
  119. print(net(x, y))
# The user-defined bprop deliberately does not take the same number of parameters as the primal construct.
  121. def test_user_defined_bad_bprop():
  122. class UserDefinedNet(nn.Cell):
  123. def __init__(self):
  124. super().__init__()
  125. self.print = P.Print()
  126. def construct(self, x, y):
  127. out = x * y
  128. return out
  129. def bprop(self, x, out, dout):
  130. self.print(out)
  131. out = x
  132. self.print(out)
  133. self.print(dout)
  134. return x, x
  135. class GradNet(nn.Cell):
  136. def __init__(self, net):
  137. super(GradNet, self).__init__()
  138. self.net = net
  139. self.parameter_tuple = ParameterTuple(net.trainable_params())
  140. def construct(self, x, y):
  141. return grad_all_list(self.net, self.parameter_tuple)(x, y)
  142. user_defined_net = UserDefinedNet()
  143. net = GradNet(user_defined_net)
  144. x = Tensor(np.array([2.0], np.float32))
  145. y = Tensor(np.array([3.0], np.float32))
  146. with pytest.raises(TypeError):
  147. net(x, y)
# Should compile successfully, and Print should be present in the final function graph.
  149. def test_unused_var():
  150. class UnusedVar(nn.Cell):
  151. def __init__(self):
  152. super().__init__()
  153. self.print = P.Print()
  154. def construct(self, x, y):
  155. shape1 = self.get_shape(x)
  156. out = x
  157. for _ in range(shape1):
  158. out = out + y
  159. return out
  160. def get_shape(self, x):
  161. self.print(x)
  162. _, c, _, _ = F.shape(x)
  163. return c
  164. net = UnusedVar()
  165. x = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
  166. y = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
  167. print(net(x, y))
# Should compile successfully, and Print should be present in the final function graph.
  169. def test_hof_unused_var():
  170. class UnusedVar(nn.Cell):
  171. def __init__(self):
  172. super().__init__()
  173. self.print = P.Print()
  174. def construct(self, x, y):
  175. shape1 = self.hof_get_shape(self.get_shape, x)
  176. out = x
  177. for _ in range(shape1):
  178. out = out + y
  179. return out
  180. def hof_get_shape(self, hof, x):
  181. return hof(x)
  182. def get_shape(self, x):
  183. self.print(x)
  184. _, c, _, _ = F.shape(x)
  185. return c
  186. net = UnusedVar()
  187. x = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
  188. y = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
  189. print(net(x, y))
# Should compile successfully, and Print should be present in the final function graph.
  191. def test_partial_hof_unused_var():
  192. class UnusedVar(nn.Cell):
  193. def __init__(self):
  194. super().__init__()
  195. self.print = P.Print()
  196. def construct(self, x, y):
  197. shape1 = self.hof_get_shape(x)()
  198. out = x
  199. for _ in range(shape1):
  200. out = out + y
  201. return out
  202. def hof_get_shape(self, x):
  203. return F.partial(self.get_shape, x)
  204. def get_shape(self, x):
  205. self.print(x)
  206. _, c, _, _ = F.shape(x)
  207. return c
  208. net = UnusedVar()
  209. x = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
  210. y = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
  211. print(net(x, y))
# Should compile successfully without entering an endless loop.
  213. def test_while_if():
  214. class WhileIfNet(nn.Cell):
  215. def __init__(self):
  216. super().__init__()
  217. self.zero = Tensor(np.zeros([1]).astype(np.float32))
  218. self.param = Parameter(Tensor(np.zeros([1]).astype(np.float32)))
  219. def construct(self, idx, end, x):
  220. out = self.zero
  221. while idx < end:
  222. if x < end:
  223. out = out + self.param * 2
  224. else:
  225. out = out + self.param
  226. idx = idx + 1
  227. return out
  228. idx = Tensor(np.array(0), dtype=ms.int32)
  229. end = Tensor(np.array(5), dtype=ms.int32)
  230. x = Tensor(np.zeros([1]).astype(np.float32))
  231. m = WhileIfNet()
  232. m(idx, end, x)
# Should compile successfully without a zeros_like_tensor argument mismatch; the
# generated graph files should not contain env_getitem or env_setitem.
# The InsertGradientOf primitive causes the func_graph bprop_construct to have the
# BackPropAutoMonad flag set, so every graph it uses is checked for side effects.
# As a result hyper_map_zeros_like takes U as a parameter, while the call site
# zeros_like(fv) does not pass a U argument.
  238. def test_grad_fv_and_insert_gradient_of():
  239. class FvAndInsertGradientNet(nn.Cell):
  240. def __init__(self):
  241. super(FvAndInsertGradientNet, self).__init__()
  242. self.gather = P.GatherV2()
  243. self.damping = Tensor(np.array([0.03, 0.03], np.float32))
  244. self.cov_step = Parameter(0, name="cov_step", requires_grad=False)
  245. self.freq = Tensor(278, ms.int32)
  246. self.getG = P.InsertGradientOf(self.save_gradient)
  247. self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
  248. def save_gradient(self, dout):
  249. self.cov_step = self.cov_step + self.freq
  250. return dout
  251. def construct(self, *inputs):
  252. # fv self.z from construct_wrapper
  253. x, = inputs
  254. self.z = x
  255. # insert_gradient_of
  256. self.gather(self.damping, self.cov_step, 0)
  257. out = self.getG(x)
  258. return out
  259. net = FvAndInsertGradientNet()
  260. input_data = Tensor(np.array([1.0], np.float32))
  261. # if use grad_all_list, the generated graph will have env_setitem
  262. # as gradient for inputs is constant zero, so it will depend on result of grad.
  263. grad_net = grad_by_list(net, ParameterTuple(net.trainable_params()))
  264. print(grad_net(input_data))
# Should compile successfully, because a CNode with the Partial primitive does not bind an additional U monad.
  266. def test_partial_parameter():
  267. z = Parameter(Tensor(np.array([True], np.bool_)), name='z')
  268. class PartialNet(nn.Cell):
  269. def __init__(self, input_z):
  270. super().__init__()
  271. self.input = input_z
  272. def construct(self):
  273. # getattr of all will be convert to Partial
  274. out = self.input.all(axis=())
  275. return out
  276. net = PartialNet(z)
  277. print(net())