test_parameter.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.context as context
import mindspore.ops.composite as C
from mindspore import Tensor, Parameter
from mindspore.nn import Cell
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, save_graphs=True)

grad_all = C.GradOperation(get_all=True)
grad_all_with_sens = C.GradOperation(sens_param=True)
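# Parses a subnet construct with mixed default arguments (int, None, tuple);
# only the positional argument is supplied, so the defaults are taken.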
def test_parser_three_default_mixed_args_subnet():
    class SubNetDefaultMixedArgs(Cell):
        def __init__(self):
            super().__init__()

        def construct(self, y, x=3, x1=None, x2=(1, 2)):
            if x == 3:
                if x1 == None:
                    return y
            return -y

    class NetOut(Cell):
        def __init__(self):
            super(NetOut, self).__init__()
            self.net_inside = SubNetDefaultMixedArgs()

        def construct(self, x, y=3):
            z = self.net_inside(x)
            return z

    tensor1 = Tensor(np.full((2, 3), 2).astype(np.float32))
    tensor2 = Tensor(np.full((3, 2), 4).astype(np.float32))
    net = NetOut()
    assert np.all(net(tensor1, tensor2).asnumpy() == tensor1.asnumpy())
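# Covers a construct that combines default arguments, *var, keyword-only
# arguments and **kwargs, called with extra positional and keyword arguments.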
# pylint: disable=keyword-arg-before-vararg
def test_net_vararg_kwonlyarg_kwarg():
    class FirstNet(Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            self.net = SecondNet()

        def construct(self, x=1, z=2 + 2 + 4, y=3):
            c = self.net(22, 33, x, y, z, 2, 3, 4, 5, key1=10, key2=20, key3=30, key4=40)
            return c

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()

        def construct(self, x, y=2, p=5, q=40, *var, key1=1, key2=3, **kwargs):
            a = x - y
            b = p * q
            c = a / b
            d = var[0] * var[1] * var[2] * var[3]
            e = key1 - key2 - kwargs["key3"] + kwargs["key4"]
            return a + b + c + d + e

    net = FirstNet()
    net()
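# Same argument mix as above, but the outer network is fed Tensor inputs
# instead of relying on its defaults.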
# pylint: disable=keyword-arg-before-vararg
def test_net_vararg_normal_input():
    class FirstNet(Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            self.net = SecondNet()

        def construct(self, x=1, z=2 + 2 + 4, y=3):
            c = self.net(22, 33, x, y, z, 2, 3, 4, 5, key1=10, key2=20, key3=30, key4=40)
            return c

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()

        def construct(self, x, y=2, p=5, q=40, *var, key1=1, key2=3, **kwargs):
            a = x - y
            b = p * q
            c = a / b
            d = var[0] * var[1] * var[2] * var[3]
            e = key1 - key2 - kwargs["key3"] + kwargs["key4"]
            return a + b + c + d + e

    x = Tensor(np.ones((2, 3, 4), np.int32))
    net = FirstNet()
    net(x, x, x)
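# Passes primitive (Maximum/Minimum) outputs through *args and keyword-only
# arguments, including unpacking *args back into an operator call.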
def test_prim_vararg_kwonlyarg():
    class FirstNet(Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            self.max = P.Maximum()
            self.min = P.Minimum()
            self.net = SecondNet()
            self.x = Tensor(np.ones((2, 3, 4), np.float32))
            self.y = Tensor(np.ones((2, 3, 4), np.float32))

        def construct(self):
            a = self.max(self.x, self.y)
            b = self.min(self.x, self.y)
            t = {"x": a, "y": b}
            c = self.net(t["x"], t["y"], a, b, z=a, r=b)
            return c

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()
            self.addN = P.AddN()
            self.max = P.Maximum()
            self.add = P.Add()

        def construct(self, x, y, *args, z=0, r=1):
            c = self.max(args[0], args[1])
            d = self.addN(args)
            e = self.max(*args)
            ret = x + y + c + d + e + z + r
            return ret

    net = FirstNet()
    net()
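# Keyword-only arguments declared with a bare * and no *args on the subnet.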
def test_no_vararg():
    class FirstNet(Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            self.max = P.Maximum()
            self.min = P.Minimum()
            self.net = SecondNet()
            self.x = Tensor(np.ones((2, 3, 4), np.float32))
            self.y = Tensor(np.ones((2, 3, 4), np.float32))

        def construct(self):
            t = {"x": self.x, "y": self.y}
            a = self.max(self.x, self.y)
            b = self.min(self.x, self.y)
            c = self.net(a, b, z=a, r=b)
            return c

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()

        def construct(self, x, y, *, z=0, r=1):
            ret = x + y + z + r
            return ret

    net = FirstNet()
    net()
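# Combines *args inputs, tuple unpacking and trainable Parameter weights
# in both cells.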
def test_net_variable_and_weights():
    class FirstNet(Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            self.max = P.Maximum()
            self.min = P.Minimum()
            self.net = SecondNet()
            self.x = Tensor(np.ones((3, 4), np.float32))
            self.y = Tensor(np.ones((3, 4), np.float32))
            self.weight = Parameter(Tensor(np.ones((2, 3, 4)).astype(np.float32)), "w1", requires_grad=True)

        def construct(self, *args):
            t = (self.x, self.y)
            a = self.max(self.x, self.weight)
            b = self.min(self.weight, args[0])
            c = self.net(a, b, *t)
            return c

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()
            self.addN = P.AddN()
            self.max = P.Maximum()
            self.add = P.Add()
            self.weight = Parameter(Tensor(np.ones((2, 3, 4), np.float32)), "w2", requires_grad=True)

        def construct(self, a, b, *args):
            c = self.max(args[0], a)
            d = self.addN(args)
            ret = a + b + c + d + self.weight
            return ret

    net = FirstNet()
    x = Tensor(np.ones((4,), np.float32))
    y = Tensor(np.ones((4,), np.float32))
    z = Tensor(np.ones((4,), np.float32))
    net(x, y, z)
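# Expands *inputs through a gradient wrapper (grad_all_with_sens), with the
# sensitivity tensor passed as the last input.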
def test_net_vargs_expand():
    class InputBackward(Cell):
        """ InputBackward definition """

        def __init__(self, network, c1=None, c2=None):
            super(InputBackward, self).__init__()
            self.network = network
            self.network.set_train()
            self.grad = grad_all_with_sens
            self.c1 = c1
            self.c2 = c2

        def construct(self, *inputs):
            return self.grad(self.network)(*inputs)

    class AddNet(Cell):
        def __init__(self):
            super(AddNet, self).__init__()

        def construct(self, x, y):
            return x + y

    net = InputBackward(AddNet())
    x = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
    y = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
    sens = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
    net.set_train()
    net(x, y, sens)
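# Feeds shape-derived constants through *args and branches on them, with the
# whole network flagged fp32 via add_flags_recursive.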
def test_mixed_precision_const_parameter():
    class NetLoss(Cell):
        def __init__(self):
            super(NetLoss, self).__init__()
            self.shape = P.Shape()
            self.up_sample1 = P.ResizeBilinear((14, 14))
            self.up_sample2 = P.ResizeBilinear((28, 28))
            self.up_sample3 = P.ResizeBilinear((36, 36))

        def construct(self, x, y, z, *args):
            ret = 0
            if args[0] == self.shape(z)[2]:
                if args[0] == 14:
                    ret = self.up_sample1(y) + x
                elif args[0] == 28:
                    ret = self.up_sample2(y) - x
                else:
                    ret = x / y
            else:
                ret = x * y
            ret = ret * z
            return ret

    class NetMain(Cell):
        def __init__(self, loss_fn):
            super(NetMain, self).__init__()
            self.loss_fn = loss_fn
            self.shape = P.Shape()

        def construct(self, x, y, z):
            size_x = self.shape(x)[2]
            size_y = self.shape(y)[2]
            ret = self.loss_fn(x, y, z, size_x, size_y)
            return ret

    loss_fn = NetLoss()
    net = NetMain(loss_fn)
    net.add_flags_recursive(fp32=True)
    x = Tensor(np.ones((1, 3, 28, 28), np.float32))
    y = Tensor(np.ones((1, 3, 14, 14), np.float32))
    z = Tensor(np.ones((1, 3, 28, 28), np.float32))
    _ = net(x, y, z)
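# Calls the network and its gradient wrapper with arguments passed by keyword,
# out of positional order.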
def test_pass_args_by_key_ward_way():
    class KeyWardNet(Cell):
        def __init__(self):
            super(KeyWardNet, self).__init__()

        def construct(self, x, y, z):
            return x + y - z

    class GradNet(Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.grad = C.GradOperation(get_all=True, sens_param=True)
            self.net = net
            self.sens = Tensor(np.ones((3, 3, 4), np.float32))

        def construct(self, x, y, z, sens):
            return self.grad(self.net)(x, y, z, sens)

    x = Tensor(np.ones((1, 3, 4), np.float32))
    y = Tensor(np.ones((1, 3, 4), np.float32))
    z = Tensor(np.ones((3, 3, 4), np.float32))
    net = KeyWardNet()
    net(x, z=z, y=y)
    grad_net = GradNet(net)
    sens = Tensor(np.ones((3, 3, 4), np.float32))
    grad_net(x, y=y, z=z, sens=sens)