You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_ops_attr_infer.py 10 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test nn ops """
  16. import numpy as np
  17. from numpy.random import normal
  18. import pytest
  19. import mindspore.nn as nn
  20. import mindspore.context as context
  21. from mindspore.ops.composite import core
  22. from mindspore.common.api import ms_function
  23. from mindspore import Tensor
  24. from mindspore.ops import functional as F
  25. from mindspore.ops import prim_attr_register, PrimitiveWithInfer
  26. context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
  27. class FakeOp(PrimitiveWithInfer):
  28. @prim_attr_register
  29. def __init__(self):
  30. """"""
  31. def infer_shape(self, x, y):
  32. self.second_shape = y
  33. self.add_prim_attr("second_shape", y)
  34. return x
  35. def infer_dtype(self, x, y):
  36. return x
  37. # test the normal case that should generate independent primitive because of different
  38. # generated attributes after inference
  39. def test_conv2d_same_primitive():
  40. class Conv2DSameNet(nn.Cell):
  41. def __init__(self):
  42. super(Conv2DSameNet, self).__init__()
  43. self.conv1 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)
  44. self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)
  45. def construct(self, x, y):
  46. r1 = self.conv1(x)
  47. r2 = self.conv2(y)
  48. return (r1, r2)
  49. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  50. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  51. net = Conv2DSameNet()
  52. net(t1, t2)
  53. # test free variable function list as parameter
  54. def test_remove_and_fv_2():
  55. @core(loop_can_uroll=True)
  56. def inner_loop(x, input_data, fv_func_list):
  57. ret = ()
  58. for fv_fn in fv_func_list:
  59. ele = fv_fn(input_data)
  60. ret += (ele,)
  61. return ret
  62. @ms_function
  63. def out_loop(input1, input_data0, input_data1):
  64. ret = ()
  65. def fv_func1(y):
  66. return input1 * y
  67. def fv_func2(y):
  68. return input1 - y
  69. fv_func_list = [fv_func1, fv_func2]
  70. ele0 = inner_loop(input1, input_data0, fv_func_list)
  71. ele1 = inner_loop(input1, input_data1, fv_func_list)
  72. ret = (ele0, ele1)
  73. return ret
  74. input_data0 = Tensor(normal(0, 0.1, (3, 3)))
  75. input_data1 = Tensor(normal(0, 0.1, (3, 1)))
  76. input1 = Tensor(normal(0, 0.1, (3, 3)))
  77. out_loop(input1, input_data0, input_data1)
  78. # test cell as high order argument
  79. # The graph with free variables used as argument is not supported yet
  80. # because of the limit of inference specialize system
  81. def test_conv2d_op_with_argi_1():
  82. class Conv2dNet(nn.Cell):
  83. def __init__(self):
  84. super(Conv2dNet, self).__init__()
  85. def construct(self, op, x):
  86. return op(x)
  87. class OpsNet(nn.Cell):
  88. def __init__(self, net):
  89. super(OpsNet, self).__init__()
  90. self.opnet = net
  91. self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)
  92. def construct(self, x, y):
  93. conv_op = self.conv2
  94. a = self.opnet(conv_op, x)
  95. b = self.opnet(conv_op, y)
  96. return (a, b)
  97. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  98. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  99. net = OpsNet(Conv2dNet())
  100. net(t1, t2)
  101. def test_conv2d_op_with_arg():
  102. class FackOpNet(nn.Cell):
  103. def __init__(self):
  104. super(FackOpNet, self).__init__()
  105. self.op = FakeOp()
  106. def construct(self, x, y):
  107. return self.op(x, y)
  108. class OpNet(nn.Cell):
  109. def __init__(self):
  110. super(OpNet, self).__init__()
  111. def construct(self, op, x, y):
  112. return op(x, y)
  113. class OpsNet(nn.Cell):
  114. def __init__(self, net):
  115. super(OpsNet, self).__init__()
  116. self.opnet = net
  117. self.op = FackOpNet()
  118. def construct(self, x, y):
  119. op = self.op
  120. a = self.opnet(op, x, y)
  121. b = self.opnet(op, y, x)
  122. return (a, b)
  123. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  124. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  125. net = OpsNet(OpNet())
  126. net(t1, t2)
  127. def test_conv2d_op_with_arg_same_input():
  128. class FackOpNet(nn.Cell):
  129. def __init__(self):
  130. super(FackOpNet, self).__init__()
  131. self.op = FakeOp()
  132. def construct(self, x, y):
  133. return self.op(x, y)
  134. class OpNet(nn.Cell):
  135. def __init__(self):
  136. super(OpNet, self).__init__()
  137. def construct(self, op, x, y):
  138. return op(x, y)
  139. class OpsNet(nn.Cell):
  140. def __init__(self, net):
  141. super(OpsNet, self).__init__()
  142. self.opnet = net
  143. self.op = FackOpNet()
  144. def construct(self, x, y):
  145. op = self.op
  146. a = self.opnet(op, x, x)
  147. b = self.opnet(op, y, x)
  148. return (a, b)
  149. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  150. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  151. net = OpsNet(OpNet())
  152. net(t1, t2)
  153. # test op with partial
  154. def test_op_as_partial():
  155. class OpAsPartial(nn.Cell):
  156. def __init__(self):
  157. super(OpAsPartial, self).__init__()
  158. self.op = FakeOp()
  159. def construct(self, x, y, z):
  160. partial_op = F.partial(self.op, x)
  161. a = partial_op(y)
  162. b = partial_op(z)
  163. return a, b
  164. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  165. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  166. t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
  167. net = OpAsPartial()
  168. net(t1, t2, t3)
  169. # test op with partial
  170. def test_op_as_partial_inside():
  171. class OpAsPartial(nn.Cell):
  172. def __init__(self):
  173. super(OpAsPartial, self).__init__()
  174. self.op = FakeOp()
  175. def construct(self, x, y, z):
  176. partial_op = F.partial(self.op, x)
  177. a = partial_op(y)
  178. b = partial_op(z)
  179. return a, b
  180. class OuterNet(nn.Cell):
  181. def __init__(self):
  182. super(OuterNet, self).__init__()
  183. self.net = OpAsPartial()
  184. def construct(self, x, y, z):
  185. a, b = self.net(x, y, z)
  186. return a, b
  187. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  188. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  189. t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
  190. net = OuterNet()
  191. net(t1, t2, t3)
  192. # test op with partial case 2
  193. def test_op_as_partial_independent():
  194. class OpAsPartial(nn.Cell):
  195. def __init__(self):
  196. super(OpAsPartial, self).__init__()
  197. self.op = FakeOp()
  198. def construct(self, x, y, z):
  199. partial_op1 = F.partial(self.op, x)
  200. a = partial_op1(y)
  201. partial_op2 = F.partial(self.op, x)
  202. b = partial_op2(z)
  203. return a, b
  204. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  205. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  206. t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
  207. net = OpAsPartial()
  208. net(t1, t2, t3)
  209. def test_nest_partial():
  210. class NestPartial(nn.Cell):
  211. def __init__(self):
  212. super(NestPartial, self).__init__()
  213. self.op = FakeOp()
  214. def construct(self, x, y, z):
  215. partial_op1 = F.partial(self.op)
  216. partial_op2 = F.partial(partial_op1, x)
  217. a = partial_op2(y)
  218. partial_op3 = F.partial(self.op)
  219. partial_op4 = F.partial(partial_op3, x)
  220. b = partial_op4(z)
  221. return a, b
  222. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  223. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  224. t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
  225. net = NestPartial()
  226. net(t1, t2, t3)
  227. # high order argument
  228. # op and op args as network arguments
  229. def test_op_with_arg_as_input():
  230. class WithOpArgNet(nn.Cell):
  231. def __init__(self):
  232. super(WithOpArgNet, self).__init__()
  233. def construct(self, op, x, y):
  234. return op(x, y)
  235. class OpsNet(nn.Cell):
  236. def __init__(self, net):
  237. super(OpsNet, self).__init__()
  238. self.opnet = net
  239. self.op = FakeOp()
  240. def construct(self, x, y, z):
  241. op = self.op
  242. a = self.opnet(op, x, z)
  243. b = self.opnet(op, x, y)
  244. return (a, b)
  245. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  246. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  247. t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
  248. net = OpsNet(WithOpArgNet())
  249. net(t1, t2, t3)
  250. # The partial application used as argument is not supported yet
  251. # because of the limit of inference specialize system
  252. @pytest.mark.skip("poly in infer")
  253. def test_partial_as_arg():
  254. class PartialArgNet(nn.Cell):
  255. def __init__(self):
  256. super(PartialArgNet, self).__init__()
  257. def construct(self, partial_op, y):
  258. return partial_op(y)
  259. class OpsNet(nn.Cell):
  260. def __init__(self, net):
  261. super(OpsNet, self).__init__()
  262. self.partial_net = net
  263. self.op = FakeOp()
  264. def construct(self, x, y, z):
  265. partial_op = F.partial(self.op, x)
  266. a = self.partial_net(partial_op, z)
  267. b = self.partial_net(partial_op, y)
  268. return (a, b)
  269. t1 = Tensor(np.ones([1, 16, 1, 1918]).astype(np.float32))
  270. t2 = Tensor(np.ones([1, 16, 1, 3840]).astype(np.float32))
  271. t3 = Tensor(np.ones([1, 16, 1, 1234]).astype(np.float32))
  272. net = OpsNet(PartialArgNet())
  273. net(t1, t2, t3)