You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_parameter.py 10 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import mindspore.context as context
  17. import mindspore.ops.composite as C
  18. from mindspore import Tensor, Parameter
  19. from mindspore import nn
  20. from mindspore.nn import Cell
  21. from mindspore.ops import operations as P
  22. context.set_context(mode=context.GRAPH_MODE)
  23. grad_all = C.GradOperation(get_all=True)
  24. grad_all_with_sens = C.GradOperation(sens_param=True)
  25. def test_parser_three_default_mixed_args_subnet():
  26. class SubNetDefaultMixedArgs(Cell):
  27. def __init__(self):
  28. super().__init__()
  29. def construct(self, y, x=3, x1=None, x2=(1, 2)):
  30. if x == 3:
  31. if x1 == None:
  32. return y
  33. return -y
  34. class NetOut(Cell):
  35. def __init__(self):
  36. super(NetOut, self).__init__()
  37. self.net_inside = SubNetDefaultMixedArgs()
  38. def construct(self, x, y=3):
  39. z = self.net_inside(x)
  40. return z
  41. tensor1 = Tensor(np.full((2, 3), 2).astype(np.float32))
  42. tensor2 = Tensor(np.full((3, 2), 4).astype(np.float32))
  43. net = NetOut()
  44. assert np.all(net(tensor1, tensor2).asnumpy() == tensor1.asnumpy())
  45. # pylint: disable=keyword-arg-before-vararg
  46. def test_net_vararg_kwonlyarg_kwarg():
  47. class FirstNet(Cell):
  48. def __init__(self):
  49. super(FirstNet, self).__init__()
  50. self.net = SecondNet()
  51. def construct(self, x=1, z=2 + 2 + 4, y=3):
  52. c = self.net(22, 33, x, y, z, 2, 3, 4, 5, key1=10, key2=20, key3=30, key4=40)
  53. return c
  54. class SecondNet(Cell):
  55. def __init__(self):
  56. super(SecondNet, self).__init__()
  57. def construct(self, x, y=2, p=5, q=40, *var, key1=1, key2=3, **kwargs):
  58. a = x - y
  59. b = p * q
  60. c = a / b
  61. d = var[0] * var[1] * var[2] * var[3]
  62. e = key1 - key2 - kwargs["key3"] + kwargs["key4"]
  63. return a + b + c + d + e
  64. net = FirstNet()
  65. net()
  66. # pylint: disable=keyword-arg-before-vararg
  67. def test_net_vararg_normal_input():
  68. class FirstNet(Cell):
  69. def __init__(self):
  70. super(FirstNet, self).__init__()
  71. self.net = SecondNet()
  72. def construct(self, x=1, z=2 + 2 + 4, y=3):
  73. c = self.net(22, 33, x, y, z, 2, 3, 4, 5, key1=10, key2=20, key3=30, key4=40)
  74. return c
  75. class SecondNet(Cell):
  76. def __init__(self):
  77. super(SecondNet, self).__init__()
  78. def construct(self, x, y=2, p=5, q=40, *var, key1=1, key2=3, **kwargs):
  79. a = x - y
  80. b = p * q
  81. c = a / b
  82. d = var[0] * var[1] * var[2] * var[3]
  83. e = key1 - key2 - kwargs["key3"] + kwargs["key4"]
  84. return a + b + c + d + e
  85. x = Tensor(np.ones((2, 3, 4), np.int32))
  86. net = FirstNet()
  87. net(x, x, x)
  88. def test_prim_vararg_kwonlyarg():
  89. class FirstNet(Cell):
  90. def __init__(self):
  91. super(FirstNet, self).__init__()
  92. self.max = P.Maximum()
  93. self.min = P.Minimum()
  94. self.net = SecondNet()
  95. self.x = Tensor(np.ones((2, 3, 4), np.float32))
  96. self.y = Tensor(np.ones((2, 3, 4), np.float32))
  97. def construct(self):
  98. a = self.max(self.x, self.y)
  99. b = self.min(self.x, self.y)
  100. t = {"x": a, "y": b}
  101. c = self.net(t["x"], t["y"], a, b, z=a, r=b)
  102. return c
  103. class SecondNet(Cell):
  104. def __init__(self):
  105. super(SecondNet, self).__init__()
  106. self.addN = P.AddN()
  107. self.max = P.Maximum()
  108. self.add = P.Add()
  109. def construct(self, x, y, *args, z=0, r=1):
  110. c = self.max(args[0], args[1])
  111. d = self.addN(args)
  112. e = self.max(*args)
  113. ret = x + y + c + d + e + z + r
  114. return ret
  115. net = FirstNet()
  116. net()
  117. def test_no_vararg():
  118. class FirstNet(Cell):
  119. def __init__(self):
  120. super(FirstNet, self).__init__()
  121. self.max = P.Maximum()
  122. self.min = P.Minimum()
  123. self.net = SecondNet()
  124. self.x = Tensor(np.ones((2, 3, 4), np.float32))
  125. self.y = Tensor(np.ones((2, 3, 4), np.float32))
  126. def construct(self):
  127. t = {"x": self.x, "y": self.y}
  128. a = self.max(self.x, self.y)
  129. b = self.min(self.x, self.y)
  130. c = self.net(a, b, z=a, r=b)
  131. return c
  132. class SecondNet(Cell):
  133. def __init__(self):
  134. super(SecondNet, self).__init__()
  135. def construct(self, x, y, *, z=0, r=1):
  136. ret = x + y + z + r
  137. return ret
  138. net = FirstNet()
  139. net()
  140. def test_net_variable_and_weights():
  141. class FirstNet(Cell):
  142. def __init__(self):
  143. super(FirstNet, self).__init__()
  144. self.max = P.Maximum()
  145. self.min = P.Minimum()
  146. self.net = SecondNet()
  147. self.x = Tensor(np.ones((3, 4), np.float32))
  148. self.y = Tensor(np.ones((3, 4), np.float32))
  149. self.weight = Parameter(Tensor(np.ones((2, 3, 4)).astype(np.float32)), "w1", requires_grad=True)
  150. def construct(self, *args):
  151. t = (self.x, self.y)
  152. a = self.max(self.x, self.weight)
  153. b = self.min(self.weight, args[0])
  154. c = self.net(a, b, *t)
  155. return c
  156. class SecondNet(Cell):
  157. def __init__(self):
  158. super(SecondNet, self).__init__()
  159. self.addN = P.AddN()
  160. self.max = P.Maximum()
  161. self.add = P.Add()
  162. self.weight = Parameter(Tensor(np.ones((2, 3, 4), np.float32)), "w2", requires_grad=True)
  163. def construct(self, a, b, *args):
  164. c = self.max(args[0], a)
  165. d = self.addN(args)
  166. ret = a + b + c + d + self.weight
  167. return ret
  168. net = FirstNet()
  169. x = Tensor(np.ones((4,), np.float32))
  170. y = Tensor(np.ones((4,), np.float32))
  171. z = Tensor(np.ones((4,), np.float32))
  172. net(x, y, z)
  173. def test_net_vargs_expand():
  174. class InputBackward(Cell):
  175. """ InputBackward definition """
  176. def __init__(self, network, c1=None, c2=None):
  177. super(InputBackward, self).__init__()
  178. self.network = network
  179. self.network.set_train()
  180. self.grad = grad_all_with_sens
  181. self.c1 = c1
  182. self.c2 = c2
  183. def construct(self, *inputs):
  184. return self.grad(self.network)(*inputs)
  185. class AddNet(Cell):
  186. def __init__(self):
  187. super(AddNet, self).__init__()
  188. def construct(self, x, y):
  189. return x + y
  190. net = InputBackward(AddNet())
  191. x = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
  192. y = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
  193. sens = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
  194. net.set_train()
  195. net(x, y, sens)
  196. def test_mixed_precision_const_parameter():
  197. class NetLoss(Cell):
  198. def __init__(self):
  199. super(NetLoss, self).__init__()
  200. self.shape = P.Shape()
  201. self.up_sample1 = P.ResizeBilinear((14, 14))
  202. self.up_sample2 = P.ResizeBilinear((28, 28))
  203. self.up_sample3 = P.ResizeBilinear((36, 36))
  204. def construct(self, x, y, z, *args):
  205. ret = 0
  206. if args[0] == self.shape(z)[2]:
  207. if args[0] == 14:
  208. ret = self.up_sample1(y) + x
  209. elif args[0] == 28:
  210. ret = self.up_sample2(y) - x
  211. else:
  212. ret = x / y
  213. else:
  214. ret = x * y
  215. ret = ret * z
  216. return ret
  217. class NetMain(Cell):
  218. def __init__(self, loss_fn):
  219. super(NetMain, self).__init__()
  220. self.loss_fn = loss_fn
  221. self.shape = P.Shape()
  222. def construct(self, x, y, z):
  223. size_x = self.shape(x)[2]
  224. size_y = self.shape(y)[2]
  225. ret = self.loss_fn(x, y, z, size_x, size_y)
  226. return ret
  227. loss_fn = NetLoss()
  228. net = NetMain(loss_fn)
  229. net.add_flags_recursive(fp32=True)
  230. x = Tensor(np.ones((1, 3, 28, 28), np.float32))
  231. y = Tensor(np.ones((1, 3, 14, 14), np.float32))
  232. z = Tensor(np.ones((1, 3, 28, 28), np.float32))
  233. _ = net(x, y, z)
  234. def test_pass_args_by_key_ward_way():
  235. class KeyWardNet(Cell):
  236. def __init__(self):
  237. super(KeyWardNet, self).__init__()
  238. def construct(self, x, y, z):
  239. return x + y - z
  240. class GradNet(Cell):
  241. def __init__(self, net):
  242. super(GradNet, self).__init__()
  243. self.grad = C.GradOperation(get_all=True, sens_param=True)
  244. self.net = net
  245. self.sens = Tensor(np.ones((3, 3, 4), np.float32))
  246. def construct(self, x, y, z, sens):
  247. return self.grad(self.net)(x, y, z, sens)
  248. x = Tensor(np.ones((1, 3, 4), np.float32))
  249. y = Tensor(np.ones((1, 3, 4), np.float32))
  250. z = Tensor(np.ones((3, 3, 4), np.float32))
  251. net = KeyWardNet()
  252. net(x, z=z, y=y)
  253. grad_net = GradNet(net)
  254. sens = Tensor(np.ones((3, 3, 4), np.float32))
  255. grad_net(x, y=y, z=z, sens=sens)
  256. def test_none_input():
  257. """
  258. Feature: Net's inputs
  259. Description: Support none input for the outermost net
  260. Expectation: no error
  261. """
  262. class Net(Cell):
  263. def __init__(self):
  264. super(Net, self).__init__()
  265. self.op = nn.ResizeBilinear()
  266. def construct(self, a, b, c, d):
  267. return self.op(a, b, c, d)
  268. x = Tensor(np.array([1, 2, 3, 4]).astype(np.float32).reshape((1, 1, 2, 2,)))
  269. net = Net()
  270. net(x, (4, 4), None, True)