You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_ops_attr_infer.py 10 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test nn ops """
  16. import numpy as np
  17. from numpy.random import normal
  18. import pytest
  19. import mindspore.nn as nn
  20. import mindspore.context as context
  21. from mindspore.ops.composite import core
  22. from mindspore.common.api import ms_function
  23. from mindspore import Tensor
  24. from mindspore.ops import functional as F
  25. from mindspore.ops import prim_attr_register, PrimitiveWithInfer
# All tests in this file rely on graph-mode compilation (attributes are
# generated during graph inference); save_graphs dumps the intermediate IR.
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
class FakeOp(PrimitiveWithInfer):
    """Stub primitive whose shape inference records an attribute.

    During inference it stores the second operand's shape both on the
    instance and as a primitive attribute, so the tests below can check
    that differing inferred attributes yield independent primitives.
    """

    @prim_attr_register
    def __init__(self):
        """"""

    def infer_shape(self, x, y):
        # Record the second operand's shape; output shape follows the
        # first operand.
        self.second_shape = y
        self.add_prim_attr("second_shape", y)
        return x

    def infer_dtype(self, x, y):
        # Output dtype follows the first operand.
        return x
  37. # test the normal case that should generate independent primitive because of different
  38. # generated attributes after inference
  39. def test_conv2d_same_primitive():
  40. class Conv2DSameNet(nn.Cell):
  41. def __init__(self):
  42. super(Conv2DSameNet, self).__init__()
  43. self.conv1 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)
  44. self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)
  45. def construct(self, x, y):
  46. r1 = self.conv1(x)
  47. r2 = self.conv2(y)
  48. return (r1, r2)
  49. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  50. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  51. net = Conv2DSameNet()
  52. net(t1, t2)
  53. # test free variable function list as parameter
  54. def test_remove_and_fv_2():
  55. @core(loop_can_uroll=True)
  56. def inner_loop(x, input_data, fv_func_list):
  57. ret = ()
  58. for fv_fn in fv_func_list:
  59. ele = fv_fn(input_data)
  60. ret += (ele,)
  61. return ret
  62. @ms_function
  63. def out_loop(input1, input_data):
  64. ret = ()
  65. def fv_func1(y):
  66. return input1 * y
  67. def fv_func2(y):
  68. return input1 - y
  69. fv_func_list = [fv_func1, fv_func2]
  70. ele0 = inner_loop(input1, input_data[0], fv_func_list)
  71. ele1 = inner_loop(input1, input_data[1], fv_func_list)
  72. ret = (ele0, ele1)
  73. return ret
  74. input_data = (Tensor(normal(0, 0.1, (3, 3))), Tensor(normal(0, 0.1, (3, 1))))
  75. input1 = Tensor(normal(0, 0.1, (3, 3)))
  76. out_loop(input1, input_data)
  77. # test cell as high order argument
  78. # The graph with free variables used as argument is not supported yet
  79. # because of the limit of inference specialize system
  80. def test_conv2d_op_with_argi_1():
  81. class Conv2dNet(nn.Cell):
  82. def __init__(self):
  83. super(Conv2dNet, self).__init__()
  84. def construct(self, op, x):
  85. return op(x)
  86. class OpsNet(nn.Cell):
  87. def __init__(self, net):
  88. super(OpsNet, self).__init__()
  89. self.opnet = net
  90. self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)
  91. def construct(self, x, y):
  92. conv_op = self.conv2
  93. a = self.opnet(conv_op, x)
  94. b = self.opnet(conv_op, y)
  95. return (a, b)
  96. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  97. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  98. net = OpsNet(Conv2dNet())
  99. net(t1, t2)
  100. def test_conv2d_op_with_arg():
  101. class FackOpNet(nn.Cell):
  102. def __init__(self):
  103. super(FackOpNet, self).__init__()
  104. self.op = FakeOp()
  105. def construct(self, x, y):
  106. return self.op(x, y)
  107. class OpNet(nn.Cell):
  108. def __init__(self):
  109. super(OpNet, self).__init__()
  110. def construct(self, op, x, y):
  111. return op(x, y)
  112. class OpsNet(nn.Cell):
  113. def __init__(self, net):
  114. super(OpsNet, self).__init__()
  115. self.opnet = net
  116. self.op = FackOpNet()
  117. def construct(self, x, y):
  118. op = self.op
  119. a = self.opnet(op, x, y)
  120. b = self.opnet(op, y, x)
  121. return (a, b)
  122. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  123. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  124. net = OpsNet(OpNet())
  125. net(t1, t2)
  126. def test_conv2d_op_with_arg_same_input():
  127. class FackOpNet(nn.Cell):
  128. def __init__(self):
  129. super(FackOpNet, self).__init__()
  130. self.op = FakeOp()
  131. def construct(self, x, y):
  132. return self.op(x, y)
  133. class OpNet(nn.Cell):
  134. def __init__(self):
  135. super(OpNet, self).__init__()
  136. def construct(self, op, x, y):
  137. return op(x, y)
  138. class OpsNet(nn.Cell):
  139. def __init__(self, net):
  140. super(OpsNet, self).__init__()
  141. self.opnet = net
  142. self.op = FackOpNet()
  143. def construct(self, x, y):
  144. op = self.op
  145. a = self.opnet(op, x, x)
  146. b = self.opnet(op, y, x)
  147. return (a, b)
  148. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  149. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  150. net = OpsNet(OpNet())
  151. net(t1, t2)
  152. # test op with partial
  153. def test_op_as_partial():
  154. class OpAsPartial(nn.Cell):
  155. def __init__(self):
  156. super(OpAsPartial, self).__init__()
  157. self.op = FakeOp()
  158. def construct(self, x, y, z):
  159. partial_op = F.partial(self.op, x)
  160. a = partial_op(y)
  161. b = partial_op(z)
  162. return a, b
  163. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  164. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  165. t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
  166. net = OpAsPartial()
  167. net(t1, t2, t3)
  168. # test op with partial
  169. def test_op_as_partial_inside():
  170. class OpAsPartial(nn.Cell):
  171. def __init__(self):
  172. super(OpAsPartial, self).__init__()
  173. self.op = FakeOp()
  174. def construct(self, x, y, z):
  175. partial_op = F.partial(self.op, x)
  176. a = partial_op(y)
  177. b = partial_op(z)
  178. return a, b
  179. class OuterNet(nn.Cell):
  180. def __init__(self):
  181. super(OuterNet, self).__init__()
  182. self.net = OpAsPartial()
  183. def construct(self, x, y, z):
  184. a, b = self.net(x, y, z)
  185. return a, b
  186. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  187. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  188. t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
  189. net = OuterNet()
  190. net(t1, t2, t3)
  191. # test op with partial case 2
  192. def test_op_as_partial_independent():
  193. class OpAsPartial(nn.Cell):
  194. def __init__(self):
  195. super(OpAsPartial, self).__init__()
  196. self.op = FakeOp()
  197. def construct(self, x, y, z):
  198. partial_op1 = F.partial(self.op, x)
  199. a = partial_op1(y)
  200. partial_op2 = F.partial(self.op, x)
  201. b = partial_op2(z)
  202. return a, b
  203. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  204. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  205. t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
  206. net = OpAsPartial()
  207. net(t1, t2, t3)
  208. def test_nest_partial():
  209. class NestPartial(nn.Cell):
  210. def __init__(self):
  211. super(NestPartial, self).__init__()
  212. self.op = FakeOp()
  213. def construct(self, x, y, z):
  214. partial_op1 = F.partial(self.op)
  215. partial_op2 = F.partial(partial_op1, x)
  216. a = partial_op2(y)
  217. partial_op3 = F.partial(self.op)
  218. partial_op4 = F.partial(partial_op3, x)
  219. b = partial_op4(z)
  220. return a, b
  221. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  222. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  223. t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
  224. net = NestPartial()
  225. net(t1, t2, t3)
  226. # high order argument
  227. # op and op args as network arguments
  228. def test_op_with_arg_as_input():
  229. class WithOpArgNet(nn.Cell):
  230. def __init__(self):
  231. super(WithOpArgNet, self).__init__()
  232. def construct(self, op, x, y):
  233. return op(x, y)
  234. class OpsNet(nn.Cell):
  235. def __init__(self, net):
  236. super(OpsNet, self).__init__()
  237. self.opnet = net
  238. self.op = FakeOp()
  239. def construct(self, x, y, z):
  240. op = self.op
  241. a = self.opnet(op, x, z)
  242. b = self.opnet(op, x, y)
  243. return (a, b)
  244. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  245. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  246. t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
  247. net = OpsNet(WithOpArgNet())
  248. net(t1, t2, t3)
  249. # The partial application used as argument is not supported yet
  250. # because of the limit of inference specialize system
  251. @pytest.mark.skip("poly in infer")
  252. def test_partial_as_arg():
  253. class PartialArgNet(nn.Cell):
  254. def __init__(self):
  255. super(PartialArgNet, self).__init__()
  256. def construct(self, partial_op, y):
  257. return partial_op(y)
  258. class OpsNet(nn.Cell):
  259. def __init__(self, net):
  260. super(OpsNet, self).__init__()
  261. self.partial_net = net
  262. self.op = FakeOp()
  263. def construct(self, x, y, z):
  264. partial_op = F.partial(self.op, x)
  265. a = self.partial_net(partial_op, z)
  266. b = self.partial_net(partial_op, y)
  267. return (a, b)
  268. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  269. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  270. t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
  271. net = OpsNet(PartialArgNet())
  272. net(t1, t2, t3)