You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_ops_attr_infer.py 10 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test nn ops """
  16. import functools
  17. import numpy as np
  18. import mindspore
  19. import mindspore.nn as nn
  20. import mindspore.context as context
  21. import mindspore.common.dtype as mstype
  22. from mindspore import Tensor, Parameter
  23. from mindspore.common.initializer import initializer
  24. from mindspore.ops import Primitive
  25. from mindspore.ops import composite as C
  26. from mindspore.ops import operations as P
  27. from mindspore.ops import functional as F
  28. from mindspore.ops import prim_attr_register, PrimitiveWithInfer
  29. from mindspore.ops.primitive import constexpr
  30. from ..ut_filter import non_graph_engine
  31. from ....mindspore_test_framework.mindspore_test import mindspore_test
  32. from ....mindspore_test_framework.pipeline.forward.compile_forward \
  33. import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
  34. from ....mindspore_test_framework.pipeline.forward.verify_exception \
  35. import pipeline_for_verify_exception_for_case_by_case_config
  36. from mindspore import context
  37. context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
  38. class FakeOp(PrimitiveWithInfer):
  39. @prim_attr_register
  40. def __init__(self):
  41. """"""
  42. def infer_shape(self, x, y):
  43. self.second_shape = y
  44. self.add_prim_attr("second_shape", y)
  45. return x
  46. def infer_dtype(self, x, y):
  47. return x
  48. # test the normal case that should generate independent primitive because of different
  49. # generated attributes after inference
  50. def test_conv2d_same_primitive():
  51. class Conv2DSameNet(nn.Cell):
  52. def __init__(self):
  53. super(Conv2DSameNet, self).__init__()
  54. self.conv1 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True)
  55. self.conv2 = nn.Conv2d(16, 64, (1, 41), (1,4), "same", 0, 1, has_bias=True)
  56. def construct(self, x, y):
  57. r1 = self.conv1(x)
  58. r2 = self.conv2(y)
  59. return (r1, r2)
  60. t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
  61. t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
  62. net = Conv2DSameNet()
  63. out = net(t1, t2)
  64. # test cell as high order argument
  65. # The graph with free variables used as argument is not supported yet
  66. # because of the limit of inference specialize system
  67. def Xtest_conv2d_op_with_arg():
  68. class Conv2dNet(nn.Cell):
  69. def __init__(self):
  70. super(Conv2dNet, self).__init__()
  71. def construct(self, op, x):
  72. return op(x)
  73. class OpsNet(nn.Cell):
  74. def __init__(self, net):
  75. super(OpsNet, self).__init__()
  76. self.opnet = net
  77. self.conv2 = nn.Conv2d(16, 64, (1, 41), (1, 4), "same", 0, 1, has_bias=True)
  78. def construct(self, x, y):
  79. conv_op = self.conv2
  80. a = self.opnet(conv_op, x)
  81. b = self.opnet(conv_op, y)
  82. return (a, b)
  83. t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
  84. t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
  85. net = OpsNet(Conv2dNet())
  86. out = net(t1, t2)
  87. def test_conv2d_op_with_arg():
  88. class FackOpNet(nn.Cell):
  89. def __init__(self):
  90. super(FackOpNet, self).__init__()
  91. self.op = FakeOp()
  92. def construct(self, x, y):
  93. return self.op(x, y)
  94. class OpNet(nn.Cell):
  95. def __init__(self):
  96. super(OpNet, self).__init__()
  97. def construct(self, op, x, y):
  98. return op(x, y)
  99. class OpsNet(nn.Cell):
  100. def __init__(self, net):
  101. super(OpsNet, self).__init__()
  102. self.opnet = net
  103. self.op = FackOpNet()
  104. def construct(self, x, y):
  105. op = self.op
  106. a = self.opnet(op, x, y)
  107. b = self.opnet(op, y, x)
  108. return (a, b)
  109. t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
  110. t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
  111. net = OpsNet(OpNet())
  112. out = net(t1, t2)
  113. def test_conv2d_op_with_arg_same_input():
  114. class FackOpNet(nn.Cell):
  115. def __init__(self):
  116. super(FackOpNet, self).__init__()
  117. self.op = FakeOp()
  118. def construct(self, x, y):
  119. return self.op(x, y)
  120. class OpNet(nn.Cell):
  121. def __init__(self):
  122. super(OpNet, self).__init__()
  123. def construct(self, op, x, y):
  124. return op(x, y)
  125. class OpsNet(nn.Cell):
  126. def __init__(self, net):
  127. super(OpsNet, self).__init__()
  128. self.opnet = net
  129. self.op = FackOpNet()
  130. def construct(self, x, y):
  131. op = self.op
  132. a = self.opnet(op, x, x)
  133. b = self.opnet(op, y, x)
  134. return (a, b)
  135. t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
  136. t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
  137. net = OpsNet(OpNet())
  138. out = net(t1, t2)
  139. # test op with partial
  140. def test_op_as_partial():
  141. class OpAsPartial(nn.Cell):
  142. def __init__(self):
  143. super(OpAsPartial, self).__init__()
  144. self.op = FakeOp()
  145. def construct(self, x, y, z):
  146. partial_op = F.partial(self.op, x)
  147. a = partial_op(y)
  148. b = partial_op(z)
  149. return a, b
  150. t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
  151. t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
  152. t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
  153. net = OpAsPartial()
  154. out = net(t1, t2, t3)
  155. # test op with partial
  156. def test_op_as_partial_inside():
  157. class OpAsPartial(nn.Cell):
  158. def __init__(self):
  159. super(OpAsPartial, self).__init__()
  160. self.op = FakeOp()
  161. def construct(self, x, y, z):
  162. partial_op = F.partial(self.op, x)
  163. a = partial_op(y)
  164. b = partial_op(z)
  165. return a, b
  166. class OuterNet(nn.Cell):
  167. def __init__(self):
  168. super(OuterNet, self).__init__()
  169. self.net = OpAsPartial()
  170. def construct(self, x, y, z):
  171. a,b = self.net(x, y, z)
  172. return a, b
  173. t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
  174. t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
  175. t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
  176. net = OuterNet()
  177. out = net(t1, t2, t3)
  178. # test op with partial case 2
  179. def test_op_as_partial_independent():
  180. class OpAsPartial(nn.Cell):
  181. def __init__(self):
  182. super(OpAsPartial, self).__init__()
  183. self.op = FakeOp()
  184. def construct(self, x, y, z):
  185. partial_op1 = F.partial(self.op, x)
  186. a = partial_op1(y)
  187. partial_op2 = F.partial(self.op, x)
  188. b = partial_op2(z)
  189. return a, b
  190. t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
  191. t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
  192. t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
  193. net = OpAsPartial()
  194. out = net(t1, t2, t3)
  195. def test_nest_partial():
  196. class NestPartial(nn.Cell):
  197. def __init__(self):
  198. super(NestPartial, self).__init__()
  199. self.op = FakeOp()
  200. def construct(self, x, y, z):
  201. partial_op1 = F.partial(self.op)
  202. partial_op2 = F.partial(partial_op1, x)
  203. a = partial_op2(y)
  204. partial_op3 = F.partial(self.op)
  205. partial_op4 = F.partial(partial_op3, x)
  206. b = partial_op4(z)
  207. return a, b
  208. t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
  209. t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
  210. t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
  211. net = NestPartial()
  212. out = net(t1, t2, t3)
  213. # high order argument
  214. # op and op args as network arguments
  215. def test_op_with_arg_as_input():
  216. class WithOpArgNet(nn.Cell):
  217. def __init__(self):
  218. super(WithOpArgNet, self).__init__()
  219. def construct(self, op, x, y):
  220. return op(x, y)
  221. class OpsNet(nn.Cell):
  222. def __init__(self, net):
  223. super(OpsNet, self).__init__()
  224. self.opnet = net
  225. self.op = FakeOp()
  226. def construct(self, x, y, z):
  227. op = self.op
  228. a = self.opnet(op, x, z)
  229. b = self.opnet(op, x, y)
  230. return (a, b)
  231. t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
  232. t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
  233. t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
  234. net = OpsNet(WithOpArgNet())
  235. out = net(t1, t2, t3)
  236. # The partial application used as argument is not supported yet
  237. # because of the limit of inference specialize system
  238. def Xtest_partial_as_arg():
  239. class PartialArgNet(nn.Cell):
  240. def __init__(self):
  241. super(PartialArgNet, self).__init__()
  242. def construct(self, partial_op, y):
  243. return partial_op(y)
  244. class OpsNet(nn.Cell):
  245. def __init__(self, net):
  246. super(OpsNet, self).__init__()
  247. self.partial_net = net
  248. self.op = FakeOp()
  249. def construct(self, x, y, z):
  250. partial_op = F.partial(self.op, x)
  251. a = self.partial_net(partial_op, z)
  252. b = self.partial_net(partial_op, y)
  253. return (a, b)
  254. t1 = Tensor(np.ones([1,16,1,1918]).astype(np.float32))
  255. t2 = Tensor(np.ones([1,16,1,3840]).astype(np.float32))
  256. t3 = Tensor(np.ones([1,16,1,1234]).astype(np.float32))
  257. net = OpsNet(PartialArgNet())
  258. out = net(t1, t2, t3)