You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

test_ops_attr_infer.py 10 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """ test nn ops """
  16. import numpy as np
  17. from numpy.random import normal
  18. import mindspore.nn as nn
  19. import mindspore.context as context
  20. from mindspore.ops.composite import core
  21. from mindspore.common.api import ms_function
  22. from mindspore import Tensor
  23. from mindspore.ops import functional as F
  24. from mindspore.ops import prim_attr_register, PrimitiveWithInfer
  25. context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
  26. class FakeOp(PrimitiveWithInfer):
  27. @prim_attr_register
  28. def __init__(self):
  29. """"""
  30. def infer_shape(self, x, y):
  31. self.second_shape = y
  32. self.add_prim_attr("second_shape", y)
  33. return x
  34. def infer_dtype(self, x, y):
  35. return x
  36. # test the normal case that should generate independent primitive because of different
  37. # generated attributes after inference
  38. def test_conv2d_same_primitive():
  39. class Conv2DSameNet(nn.Cell):
  40. def __init__(self):
  41. super(Conv2DSameNet, self).__init__()
  42. self.conv1 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)
  43. self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)
  44. def construct(self, x, y):
  45. r1 = self.conv1(x)
  46. r2 = self.conv2(y)
  47. return (r1, r2)
  48. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  49. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  50. net = Conv2DSameNet()
  51. net(t1, t2)
  52. # test free variable function list as parameter
  53. def test_remove_and_fv_2():
  54. @core(loop_can_uroll=True)
  55. def inner_loop(x, input_data, fv_func_list):
  56. ret = ()
  57. for fv_fn in fv_func_list:
  58. ele = fv_fn(input_data)
  59. ret += (ele,)
  60. return ret
  61. @ms_function
  62. def out_loop(input1, input_data):
  63. ret = ()
  64. def fv_func1(y):
  65. return input1 * y
  66. def fv_func2(y):
  67. return input1 - y
  68. fv_func_list = [fv_func1, fv_func2]
  69. ele0 = inner_loop(input1, input_data[0], fv_func_list)
  70. ele1 = inner_loop(input1, input_data[1], fv_func_list)
  71. ret = (ele0, ele1)
  72. return ret
  73. input_data = (Tensor(normal(0, 0.1, (3, 3))), Tensor(normal(0, 0.1, (3, 1))))
  74. input1 = Tensor(normal(0, 0.1, (3, 3)))
  75. out_loop(input1, input_data)
  76. # test cell as high order argument
  77. # The graph with free variables used as argument is not supported yet
  78. # because of the limit of inference specialize system
  79. def test_conv2d_op_with_argi_1():
  80. class Conv2dNet(nn.Cell):
  81. def __init__(self):
  82. super(Conv2dNet, self).__init__()
  83. def construct(self, op, x):
  84. return op(x)
  85. class OpsNet(nn.Cell):
  86. def __init__(self, net):
  87. super(OpsNet, self).__init__()
  88. self.opnet = net
  89. self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)
  90. def construct(self, x, y):
  91. conv_op = self.conv2
  92. a = self.opnet(conv_op, x)
  93. b = self.opnet(conv_op, y)
  94. return (a, b)
  95. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  96. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  97. net = OpsNet(Conv2dNet())
  98. net(t1, t2)
  99. def test_conv2d_op_with_arg():
  100. class FackOpNet(nn.Cell):
  101. def __init__(self):
  102. super(FackOpNet, self).__init__()
  103. self.op = FakeOp()
  104. def construct(self, x, y):
  105. return self.op(x, y)
  106. class OpNet(nn.Cell):
  107. def __init__(self):
  108. super(OpNet, self).__init__()
  109. def construct(self, op, x, y):
  110. return op(x, y)
  111. class OpsNet(nn.Cell):
  112. def __init__(self, net):
  113. super(OpsNet, self).__init__()
  114. self.opnet = net
  115. self.op = FackOpNet()
  116. def construct(self, x, y):
  117. op = self.op
  118. a = self.opnet(op, x, y)
  119. b = self.opnet(op, y, x)
  120. return (a, b)
  121. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  122. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  123. net = OpsNet(OpNet())
  124. net(t1, t2)
  125. def test_conv2d_op_with_arg_same_input():
  126. class FackOpNet(nn.Cell):
  127. def __init__(self):
  128. super(FackOpNet, self).__init__()
  129. self.op = FakeOp()
  130. def construct(self, x, y):
  131. return self.op(x, y)
  132. class OpNet(nn.Cell):
  133. def __init__(self):
  134. super(OpNet, self).__init__()
  135. def construct(self, op, x, y):
  136. return op(x, y)
  137. class OpsNet(nn.Cell):
  138. def __init__(self, net):
  139. super(OpsNet, self).__init__()
  140. self.opnet = net
  141. self.op = FackOpNet()
  142. def construct(self, x, y):
  143. op = self.op
  144. a = self.opnet(op, x, x)
  145. b = self.opnet(op, y, x)
  146. return (a, b)
  147. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  148. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  149. net = OpsNet(OpNet())
  150. net(t1, t2)
  151. # test op with partial
  152. def test_op_as_partial():
  153. class OpAsPartial(nn.Cell):
  154. def __init__(self):
  155. super(OpAsPartial, self).__init__()
  156. self.op = FakeOp()
  157. def construct(self, x, y, z):
  158. partial_op = F.partial(self.op, x)
  159. a = partial_op(y)
  160. b = partial_op(z)
  161. return a, b
  162. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  163. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  164. t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
  165. net = OpAsPartial()
  166. net(t1, t2, t3)
  167. # test op with partial
  168. def test_op_as_partial_inside():
  169. class OpAsPartial(nn.Cell):
  170. def __init__(self):
  171. super(OpAsPartial, self).__init__()
  172. self.op = FakeOp()
  173. def construct(self, x, y, z):
  174. partial_op = F.partial(self.op, x)
  175. a = partial_op(y)
  176. b = partial_op(z)
  177. return a, b
  178. class OuterNet(nn.Cell):
  179. def __init__(self):
  180. super(OuterNet, self).__init__()
  181. self.net = OpAsPartial()
  182. def construct(self, x, y, z):
  183. a, b = self.net(x, y, z)
  184. return a, b
  185. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  186. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  187. t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
  188. net = OuterNet()
  189. net(t1, t2, t3)
  190. # test op with partial case 2
  191. def test_op_as_partial_independent():
  192. class OpAsPartial(nn.Cell):
  193. def __init__(self):
  194. super(OpAsPartial, self).__init__()
  195. self.op = FakeOp()
  196. def construct(self, x, y, z):
  197. partial_op1 = F.partial(self.op, x)
  198. a = partial_op1(y)
  199. partial_op2 = F.partial(self.op, x)
  200. b = partial_op2(z)
  201. return a, b
  202. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  203. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  204. t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
  205. net = OpAsPartial()
  206. net(t1, t2, t3)
  207. def test_nest_partial():
  208. class NestPartial(nn.Cell):
  209. def __init__(self):
  210. super(NestPartial, self).__init__()
  211. self.op = FakeOp()
  212. def construct(self, x, y, z):
  213. partial_op1 = F.partial(self.op)
  214. partial_op2 = F.partial(partial_op1, x)
  215. a = partial_op2(y)
  216. partial_op3 = F.partial(self.op)
  217. partial_op4 = F.partial(partial_op3, x)
  218. b = partial_op4(z)
  219. return a, b
  220. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  221. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  222. t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
  223. net = NestPartial()
  224. net(t1, t2, t3)
  225. # high order argument
  226. # op and op args as network arguments
  227. def test_op_with_arg_as_input():
  228. class WithOpArgNet(nn.Cell):
  229. def __init__(self):
  230. super(WithOpArgNet, self).__init__()
  231. def construct(self, op, x, y):
  232. return op(x, y)
  233. class OpsNet(nn.Cell):
  234. def __init__(self, net):
  235. super(OpsNet, self).__init__()
  236. self.opnet = net
  237. self.op = FakeOp()
  238. def construct(self, x, y, z):
  239. op = self.op
  240. a = self.opnet(op, x, z)
  241. b = self.opnet(op, x, y)
  242. return (a, b)
  243. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  244. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  245. t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
  246. net = OpsNet(WithOpArgNet())
  247. net(t1, t2, t3)
  248. # The partial application used as argument is not supported yet
  249. # because of the limit of inference specialize system
  250. def test_partial_as_arg():
  251. class PartialArgNet(nn.Cell):
  252. def __init__(self):
  253. super(PartialArgNet, self).__init__()
  254. def construct(self, partial_op, y):
  255. return partial_op(y)
  256. class OpsNet(nn.Cell):
  257. def __init__(self, net):
  258. super(OpsNet, self).__init__()
  259. self.partial_net = net
  260. self.op = FakeOp()
  261. def construct(self, x, y, z):
  262. partial_op = F.partial(self.op, x)
  263. a = self.partial_net(partial_op, z)
  264. b = self.partial_net(partial_op, y)
  265. return (a, b)
  266. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  267. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  268. t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
  269. net = OpsNet(PartialArgNet())
  270. net(t1, t2, t3)