
test_parameter.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore.context as context
import mindspore.ops.composite as C
from mindspore import Tensor, Parameter
from mindspore.nn import Cell
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, save_graphs=True)


def test_parser_three_default_mixed_args_subnet():
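    """Parse a subnet whose construct mixes a required argument with positional, None, and tuple defaults; the caller supplies only the required argument."""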
    class SubNetDefaultMixedArgs(Cell):
        def __init__(self):
            super().__init__()

        def construct(self, y, x=3, x1=None, x2=(1, 2)):
            if x == 3:
                if x1 == None:
                    return y
            return -y

    class NetOut(Cell):
        def __init__(self):
            super(NetOut, self).__init__()
            self.net_inside = SubNetDefaultMixedArgs()

        def construct(self, x, y=3):
            z = self.net_inside(x)
            return z

    tensor1 = Tensor(np.full((2, 3), 2).astype(np.float32))
    tensor2 = Tensor(np.full((3, 2), 4).astype(np.float32))
    net = NetOut()
    assert net(tensor1, tensor2) == tensor1


def test_net_vararg_kwonlyarg_kwarg():
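    """Call a cell whose construct combines defaults, *var varargs, keyword-only arguments, and **kwargs, using constant inputs only."""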
    class FirstNet(Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            self.net = SecondNet()

        def construct(self, x=1, z=2 + 2 + 4, y=3):
            c = self.net(22, 33, x, y, z, 2, 3, 4, 5, key1=10, key2=20, key3=30, key4=40)
            return c

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()

        def construct(self, x, y=2, p=5, q=40, *var, key1=1, key2=3, **kwargs):
            a = x - y
            b = p * q
            c = a / b
            d = var[0] * var[1] * var[2] * var[3]
            e = key1 - key2 - kwargs["key3"] + kwargs["key4"]
            return a + b + c + d + e

    net = FirstNet()
    net()


def test_net_vararg_normal_input():
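    """Same vararg/keyword-only/**kwargs cell signature as test_net_vararg_kwonlyarg_kwarg, but the outer cell is called with Tensor inputs."""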
    class FirstNet(Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            self.net = SecondNet()

        def construct(self, x=1, z=2 + 2 + 4, y=3):
            c = self.net(22, 33, x, y, z, 2, 3, 4, 5, key1=10, key2=20, key3=30, key4=40)
            return c

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()

        def construct(self, x, y=2, p=5, q=40, *var, key1=1, key2=3, **kwargs):
            a = x - y
            b = p * q
            c = a / b
            d = var[0] * var[1] * var[2] * var[3]
            e = key1 - key2 - kwargs["key3"] + kwargs["key4"]
            return a + b + c + d + e

    x = Tensor(np.ones((2, 3, 4), np.int32))
    net = FirstNet()
    net(x, x, x)


def test_prim_vararg_kwonlyarg():
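    """Unpack *args and keyword-only arguments inside a cell that feeds them to primitive ops (Maximum, AddN)."""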
    class FirstNet(Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            self.max = P.Maximum()
            self.min = P.Minimum()
            self.net = SecondNet()
            self.x = Tensor(np.ones((2, 3, 4), np.float32))
            self.y = Tensor(np.ones((2, 3, 4), np.float32))

        def construct(self):
            a = self.max(self.x, self.y)
            b = self.min(self.x, self.y)
            t = {"x": a, "y": b}
            c = self.net(t["x"], t["y"], a, b, z=a, r=b)
            return c

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()
            self.addN = P.AddN()
            self.max = P.Maximum()
            self.add = P.TensorAdd()

        def construct(self, x, y, *args, z=0, r=1):
            c = self.max(args[0], args[1])
            d = self.addN(args)
            e = self.max(*args)
            ret = x + y + c + d + e + z + r
            return ret

    net = FirstNet()
    net()


def test_no_vararg():
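    """Call a cell whose construct takes keyword-only arguments after a bare * marker and no varargs."""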
    class FirstNet(Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            self.max = P.Maximum()
            self.min = P.Minimum()
            self.net = SecondNet()
            self.x = Tensor(np.ones((2, 3, 4), np.float32))
            self.y = Tensor(np.ones((2, 3, 4), np.float32))

        def construct(self):
            t = {"x": self.x, "y": self.y}
            a = self.max(self.x, self.y)
            b = self.min(self.x, self.y)
            c = self.net(a, b, z=a, r=b)
            return c

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()

        def construct(self, x, y, *, z=0, r=1):
            ret = x + y + z + r
            return ret

    net = FirstNet()
    net()


def test_net_variable_and_weights():
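    """Mix graph inputs passed through *args with constant member tensors and trainable Parameters in nested cells."""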
    class FirstNet(Cell):
        def __init__(self):
            super(FirstNet, self).__init__()
            self.max = P.Maximum()
            self.min = P.Minimum()
            self.net = SecondNet()
            self.x = Tensor(np.ones((3, 4), np.float32))
            self.y = Tensor(np.ones((3, 4), np.float32))
            self.weight = Parameter(Tensor(np.ones((2, 3, 4)).astype(np.float32)), "w1", requires_grad=True)

        def construct(self, *args):
            t = (self.x, self.y)
            a = self.max(self.x, self.weight)
            b = self.min(self.weight, args[0])
            c = self.net(a, b, *t)
            return c

    class SecondNet(Cell):
        def __init__(self):
            super(SecondNet, self).__init__()
            self.addN = P.AddN()
            self.max = P.Maximum()
            self.add = P.TensorAdd()
            self.weight = Parameter(Tensor(np.ones((2, 3, 4), np.float32)), "w2", requires_grad=True)

        def construct(self, a, b, *args):
            c = self.max(args[0], a)
            d = self.addN(args)
            ret = a + b + c + d + self.weight
            return ret

    net = FirstNet()
    x = Tensor(np.ones((4,), np.float32))
    y = Tensor(np.ones((4,), np.float32))
    z = Tensor(np.ones((4,), np.float32))
    net(x, y, z)


def test_net_vargs_expand():
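    """Expand *inputs into a gradient computation built with C.grad_all_with_sens, passing an explicit sens tensor."""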
    class InputBackward(Cell):
        """ InputBackward definition """

        def __init__(self, network, c1=None, c2=None):
            super(InputBackward, self).__init__()
            self.network = network
            self.network.set_train()
            self.grad = C.grad_all_with_sens
            self.c1 = c1
            self.c2 = c2

        def construct(self, *inputs):
            return self.grad(self.network)(*inputs)

    class AddNet(Cell):
        def __init__(self):
            super(AddNet, self).__init__()

        def construct(self, x, y):
            return x + y

    net = InputBackward(AddNet())
    x = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
    y = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
    sens = Tensor(np.random.normal(0, 1, [3, 4, 5]).astype(np.float32))
    net.set_train()
    net(x, y, sens)


def test_mixed_precision_const_parameter():
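    """Branch on shape-derived values passed through *args while the whole network is flagged fp32 via add_flags_recursive."""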
    class NetLoss(Cell):
        def __init__(self):
            super(NetLoss, self).__init__()
            self.shape = P.Shape()
            self.up_sample1 = P.ResizeBilinear((14, 14))
            self.up_sample2 = P.ResizeBilinear((28, 28))
            self.up_sample3 = P.ResizeBilinear((36, 36))

        def construct(self, x, y, z, *args):
            ret = 0
            if args[0] == self.shape(z)[2]:
                if args[0] == 14:
                    ret = self.up_sample1(y) + x
                elif args[0] == 28:
                    ret = self.up_sample2(y) - x
                else:
                    ret = x / y
            else:
                ret = x * y
            ret = ret * z
            return ret

    class NetMain(Cell):
        def __init__(self, loss_fn):
            super(NetMain, self).__init__()
            self.loss_fn = loss_fn
            self.shape = P.Shape()

        def construct(self, x, y, z):
            size_x = self.shape(x)[2]
            size_y = self.shape(y)[2]
            ret = self.loss_fn(x, y, z, size_x, size_y)
            return ret

    loss_fn = NetLoss()
    net = NetMain(loss_fn)
    net.add_flags_recursive(fp32=True)
    x = Tensor(np.ones((1, 3, 28, 28), np.float32))
    y = Tensor(np.ones((1, 3, 14, 14), np.float32))
    z = Tensor(np.ones((1, 3, 28, 28), np.float32))
    out = net(x, y, z)