You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_parse_method.py 9.2 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. @File : test_parse_method.py
  17. @Author:
  18. @Date : 2019-06-27
  19. @Desc : test parse the object's method
  20. """
  21. import logging
  22. from dataclasses import dataclass
  23. import numpy as np
  24. import pytest
  25. import mindspore.nn as nn
  26. from mindspore import context
  27. from mindspore._extends.parse.standard_method import ms_len
  28. from mindspore.common.api import ms_function
  29. from mindspore.common.tensor import Tensor
  30. from mindspore.ops.composite import core
  31. from mindspore.ops.primitive import constexpr
  32. from ..ut_filter import non_graph_engine
  33. def setup_module(module):
  34. context.set_context(mode=context.PYNATIVE_MODE)
  35. log = logging.getLogger("test")
  36. log.setLevel(level=logging.ERROR)
  37. @ms_function
  38. def default_parameter_f(x, y=3):
  39. """ default_parameter_f """
  40. z = x + y
  41. return z
# Test case: parse a function that uses a default parameter
  43. def test_parse_defalut_parameter_case1():
  44. """ Test default parameter function call """
  45. log.debug("begin test_parse_defalut_parameter_case1")
  46. ret = default_parameter_f(2)
  47. log.debug("finished test_parse_defalut_parameter_case1, ret = %r", ret)
  48. def get_val_fn(x):
  49. """ get_val_fn """
  50. ret = x + 3
  51. return ret
  52. # Test case: test bool not
  53. @ms_function
  54. def bool_exp(x, y):
  55. """ bool_exp """
  56. return not x > y
  57. def test_bool_exp():
  58. """ test_bool_exp """
  59. bool_exp(1, 2)
# Test case: use variable positional parameters (*args) in an @ms_function function
  61. @ms_function
  62. def var_parameter_f(x, *args):
  63. """ var_parameter_f """
  64. z = x + args[0] + args[1] + args[2]
  65. return z
  66. def test_var_parameter_case1():
  67. """ test_var_parameter_case1 """
  68. log.debug("start test_var_parameter_case1")
  69. var_parameter_f(1, 2, 3, 4, 5)
  70. log.debug("end test_var_parameter_case1")
  71. class Net(nn.Cell):
  72. """ Net definition """
  73. def __init__(self, value1):
  74. super(Net, self).__init__()
  75. self.relu = nn.ReLU()
  76. self.softmax = nn.Softmax(0)
  77. self.axis = 0
  78. self.TC = ClassTest("test_class", 1.2)
  79. self.value = value1
  80. @ms_function
  81. def construct(self, x):
  82. x = self.get_test_value(x)
  83. return x
  84. def get_test_value(self, x):
  85. ret = x + self.value
  86. return ret
  87. class ClassTest:
  88. """ ClassTest definition """
  89. def __init__(self, name, value1):
  90. self.name = name
  91. self.value = value1
  92. def get_name(self):
  93. return self.name
  94. def get_value(self, inc):
  95. ret = self.value + inc
  96. return ret
  97. def __call__(self, *args, **kwargs):
  98. pass
  99. # Test: call method on parse graph code
  100. @non_graph_engine
  101. def test_call_method_on_construct():
  102. """ test_call_method_on_construct """
  103. log.debug("begin test_call_method_on_construct")
  104. x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
  105. y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
  106. z = np.array([[3, 5, 7], [2, 3, 5]]).astype(np.int32)
  107. net = Net(y)
  108. output = net.construct(x)
  109. result = output.asnumpy()
  110. print(result)
  111. assert np.all(result == z)
  112. log.debug("finished test_call_method_on_construct")
# Test: call a method of another object (self.TC) from parsed graph code
  114. class Net1(nn.Cell):
  115. """ Net1 definition """
  116. def __init__(self, v1, v2):
  117. super(Net1, self).__init__()
  118. self.relu = nn.ReLU()
  119. self.softmax = nn.Softmax(0)
  120. self.axis = 0
  121. self.TC = ClassTest("test_class", v1)
  122. self.value = v2
  123. @ms_function
  124. def construct(self, x):
  125. x = x + self.TC.get_value(self.value)
  126. return x
  127. @non_graph_engine
  128. def test_call_other_object_method():
  129. """ test_call_other_object_method """
  130. log.debug("begin test_call_other_object_method")
  131. x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
  132. y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
  133. y1 = Tensor(np.array([[5, 4, 5], [1, 1, 2]]).astype(np.int32))
  134. z = np.array([[8, 9, 12], [3, 4, 7]]).astype(np.int32)
  135. net = Net1(y, y1)
  136. with pytest.raises(TypeError):
  137. output = net.construct(x)
  138. result = output.asnumpy()
  139. print(result)
  140. assert np.all(result == z)
  141. log.debug("finished test_call_other_object_method")
  142. # Test: call global object method(not self) on parse graph code
  143. value = Tensor(np.array([[3, 4, 5], [1, 1, 2]]).astype(np.int32))
  144. TC = ClassTest("test_class", value)
  145. class Net2(nn.Cell):
  146. """ Net2 definition """
  147. def __init__(self, value1):
  148. super(Net2, self).__init__()
  149. self.value = value1
  150. @ms_function
  151. def construct(self, x):
  152. x = x + TC.get_value(self.value)
  153. return x
  154. @ms_function
  155. def construct1(self, x):
  156. x = x + TC.value
  157. x = x + self.value
  158. return x
  159. @non_graph_engine
  160. def test_call_no_self_other_object_method():
  161. """ test_call_no_self_other_object_method """
  162. log.debug("begin test_call_other_object_method")
  163. x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
  164. y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
  165. z = np.array([[6, 9, 12], [3, 4, 7]]).astype(np.int32)
  166. net = Net2(y)
  167. with pytest.raises(TypeError):
  168. output = net.construct(x)
  169. result = output.asnumpy()
  170. print(result)
  171. assert np.all(result == z)
  172. log.debug("finished test_call_other_object_method")
  173. def test_call_no_self_other_object_attr_value():
  174. """ test_call_no_self_other_object_attr_value """
  175. # do not support tensor as init input.
  176. return
# Test case: use * to unpack the varargs inside an @ms_function function
  178. def vararg1(x, y):
  179. """ vararg1 """
  180. z = x + y
  181. return z
  182. def varargs_main(fn):
  183. """ varargs_main """
  184. @ms_function
  185. def t1(*args):
  186. return fn(*args)
  187. return t1
  188. def test_var_parameter_case3():
  189. """ test_var_parameter_case3 """
  190. log.debug("start test_var_parameter_case3")
  191. ret = varargs_main(vararg1)(1, 2)
  192. log.debug("ret = %r", ret)
  193. log.debug("end test_var_parameter_case3")
  194. # Test case: test the flag set
  195. @core(tg=True)
  196. def set_flag(x):
  197. """ set_flag """
  198. return x + 1
  199. @ms_function
  200. def set_test_flag_main(x, y):
  201. """ set_test_flag_main """
  202. z = set_flag(x)
  203. z = z + y
  204. return z
  205. def test_set_flag():
  206. """ Test default parameter function call """
  207. log.debug("begin test_set_flag")
  208. ret = set_test_flag_main(2, 3)
  209. log.debug("finished test_set_flag, ret = %r", ret)
  210. @dataclass
  211. class Access:
  212. a: int
  213. b: int
  214. def max(self):
  215. if self.a > self.b:
  216. return self.a
  217. return self.b
  218. @ms_function
  219. def invoke_dataclass(x, y):
  220. """ invoke_dataclass """
  221. acs = Access(x, y)
  222. return acs.max()
  223. def test_access():
  224. """ test_access """
  225. invoke_dataclass(1, 2)
  226. def myfunc(x):
  227. """ myfunc """
  228. return x * x
  229. @ms_function
  230. def ms_infer_for():
  231. """ ms_infer_for """
  232. a = 0.0
  233. for x in [1.1, 2.3, 3.3]:
  234. a = a + x
  235. return a
  236. def test_infer_for():
  237. """ test_infer_for """
  238. ms_infer_for()
  239. @ms_function
  240. def ms_infer_for_func(y):
  241. """ ms_infer_for_func """
  242. for x in [1.0, 2.0, 3.0]:
  243. y = myfunc(x) + y
  244. return y
  245. def test_ms_infer_for_func():
  246. """ test_ms_infer_for_func """
  247. ms_infer_for_func(1.0)
  248. @ms_function
  249. def add(x, y):
  250. """ add """
  251. return x + y
  252. def test_add():
  253. """ test_add """
  254. res = add(1, 2.0)
  255. return res
  256. @ms_function
  257. def add_list():
  258. """ add_list """
  259. a = [1, 2, 3]
  260. b = a[1] + a[2]
  261. return b
  262. def test_list():
  263. """ test_list """
  264. return add_list()
  265. @ms_function
  266. def compare_list_len():
  267. """ compare_list_len """
  268. a = [1, 2, 3]
  269. return ms_len(a)
  270. def test_list_len():
  271. """ test_list_len """
  272. compare_list_len()
  273. @ms_function
  274. def add_tuple():
  275. """ add_tuple """
  276. a = (1, 2, 3)
  277. b = a[1] + a[2]
  278. return b
  279. def test_tuple():
  280. """ test_tuple """
  281. return add_tuple()
  282. def invoke_func(x):
  283. """ invoke_func """
  284. return x * x
  285. @ms_function
  286. def tuple_of_node(x, y):
  287. """ tuple_of_node """
  288. a = invoke_func(x)
  289. b = invoke_func(y)
  290. c = (a, b)
  291. d = c[1] * x
  292. return d
  293. def test_tuple_node():
  294. """ test_tuple_node """
  295. res = tuple_of_node(1, 2)
  296. return res
  297. @ms_function
  298. def range_spec(x, y):
  299. """ range_spec """
  300. for _ in range(1, 10, 3):
  301. x = x + 1
  302. return x + y
  303. def test_range():
  304. """ test_range """
  305. res = range_spec(10, 10)
  306. return res
  307. def test_expr():
  308. """ test const expr """
  309. a = (1, 2)
  310. @constexpr
  311. def tuple_len(x):
  312. assert len(x) == 2
  313. tuple_len(a)