You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_parse_method.py 9.8 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. @File : test_parse_method.py
  17. @Author:
  18. @Date : 2019-06-27
  19. @Desc : test parse the object's method
  20. """
  21. import logging
  22. from dataclasses import dataclass
  23. import numpy as np
  24. import pytest
  25. import mindspore.nn as nn
  26. from mindspore import context
  27. from mindspore._extends.parse.standard_method import ms_len
  28. from mindspore.common.api import ms_function
  29. from mindspore.common.tensor import Tensor
  30. from mindspore.ops.composite import core
  31. from mindspore.ops.primitive import constexpr
  32. from mindspore.ops import functional as F
  33. from ..ut_filter import non_graph_engine
  34. def setup_module(module):
  35. context.set_context(mode=context.PYNATIVE_MODE)
  36. log = logging.getLogger("test")
  37. log.setLevel(level=logging.ERROR)
  38. @ms_function
  39. def default_parameter_f(x, y=3):
  40. """ default_parameter_f """
  41. z = x + y
  42. return z
  43. # Test case: test parse fn that use default parameter
  44. def test_parse_defalut_parameter_case1():
  45. """ Test default parameter function call """
  46. log.debug("begin test_parse_defalut_parameter_case1")
  47. ret = default_parameter_f(2)
  48. log.debug("finished test_parse_defalut_parameter_case1, ret = %r", ret)
  49. def get_val_fn(x):
  50. """ get_val_fn """
  51. ret = x + 3
  52. return ret
  53. # Test case: test bool not
  54. @ms_function
  55. def bool_exp(x, y):
  56. """ bool_exp """
  57. return not x > y
  58. def test_bool_exp():
  59. """ test_bool_exp """
  60. bool_exp(1, 2)
  61. # Test case: use the variable parameter for @mindspore
  62. @ms_function
  63. def var_parameter_f(x, *args):
  64. """ var_parameter_f """
  65. z = x + args[0] + args[1] + args[2]
  66. return z
  67. def test_var_parameter_case1():
  68. """ test_var_parameter_case1 """
  69. log.debug("start test_var_parameter_case1")
  70. var_parameter_f(1, 2, 3, 4, 5)
  71. log.debug("end test_var_parameter_case1")
  72. class Net(nn.Cell):
  73. """ Net definition """
  74. def __init__(self, value1):
  75. super(Net, self).__init__()
  76. self.relu = nn.ReLU()
  77. self.softmax = nn.Softmax(0)
  78. self.axis = 0
  79. self.TC = ClassTest("test_class", 1.2)
  80. self.value = value1
  81. @ms_function
  82. def construct(self, x):
  83. x = self.get_test_value(x)
  84. return x
  85. def get_test_value(self, x):
  86. ret = x + self.value
  87. return ret
  88. class ClassTest:
  89. """ ClassTest definition """
  90. def __init__(self, name, value1):
  91. self.name = name
  92. self.value = value1
  93. def get_name(self):
  94. return self.name
  95. def get_value(self, inc):
  96. ret = self.value + inc
  97. return ret
  98. def __call__(self, *args, **kwargs):
  99. pass
  100. # Test: call method on parse graph code
  101. @non_graph_engine
  102. def test_call_method_on_construct():
  103. """ test_call_method_on_construct """
  104. log.debug("begin test_call_method_on_construct")
  105. x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
  106. y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
  107. z = np.array([[3, 5, 7], [2, 3, 5]]).astype(np.int32)
  108. net = Net(y)
  109. output = net.construct(x)
  110. result = output.asnumpy()
  111. print(result)
  112. assert np.all(result == z)
  113. log.debug("finished test_call_method_on_construct")
  114. # Test: call method on parse graph code
  115. class Net1(nn.Cell):
  116. """ Net1 definition """
  117. def __init__(self, v1, v2):
  118. super(Net1, self).__init__()
  119. self.relu = nn.ReLU()
  120. self.softmax = nn.Softmax(0)
  121. self.axis = 0
  122. self.TC = ClassTest("test_class", v1)
  123. self.value = v2
  124. @ms_function
  125. def construct(self, x):
  126. x = x + self.TC.get_value(self.value)
  127. return x
  128. @non_graph_engine
  129. def test_call_other_object_method():
  130. """ test_call_other_object_method """
  131. log.debug("begin test_call_other_object_method")
  132. x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
  133. y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
  134. y1 = Tensor(np.array([[5, 4, 5], [1, 1, 2]]).astype(np.int32))
  135. z = np.array([[8, 9, 12], [3, 4, 7]]).astype(np.int32)
  136. net = Net1(y, y1)
  137. with pytest.raises(TypeError):
  138. output = net.construct(x)
  139. result = output.asnumpy()
  140. print(result)
  141. assert np.all(result == z)
  142. log.debug("finished test_call_other_object_method")
  143. # Test: call global object method(not self) on parse graph code
  144. value = Tensor(np.array([[3, 4, 5], [1, 1, 2]]).astype(np.int32))
  145. TC = ClassTest("test_class", value)
  146. class Net2(nn.Cell):
  147. """ Net2 definition """
  148. def __init__(self, value1):
  149. super(Net2, self).__init__()
  150. self.value = value1
  151. @ms_function
  152. def construct(self, x):
  153. x = x + TC.get_value(self.value)
  154. return x
  155. @ms_function
  156. def construct1(self, x):
  157. x = x + TC.value
  158. x = x + self.value
  159. return x
  160. @non_graph_engine
  161. def test_call_no_self_other_object_method():
  162. """ test_call_no_self_other_object_method """
  163. log.debug("begin test_call_other_object_method")
  164. x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
  165. y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
  166. z = np.array([[6, 9, 12], [3, 4, 7]]).astype(np.int32)
  167. net = Net2(y)
  168. with pytest.raises(TypeError):
  169. output = net.construct(x)
  170. result = output.asnumpy()
  171. print(result)
  172. assert np.all(result == z)
  173. log.debug("finished test_call_other_object_method")
  174. def test_call_no_self_other_object_attr_value():
  175. """ test_call_no_self_other_object_attr_value """
  176. # do not support tensor as init input.
  177. return
  178. # Test case: use the * to unlock the varargs for @mindspore
  179. def vararg1(x, y):
  180. """ vararg1 """
  181. z = x + y
  182. return z
  183. def varargs_main(fn):
  184. """ varargs_main """
  185. @ms_function
  186. def t1(*args):
  187. return fn(*args)
  188. return t1
  189. def test_var_parameter_case3():
  190. """ test_var_parameter_case3 """
  191. log.debug("start test_var_parameter_case3")
  192. ret = varargs_main(vararg1)(1, 2)
  193. log.debug("ret = %r", ret)
  194. log.debug("end test_var_parameter_case3")
  195. # Test case: test the flag set
  196. @core(tg=True)
  197. def set_flag(x):
  198. """ set_flag """
  199. return x + 1
  200. @ms_function
  201. def set_test_flag_main(x, y):
  202. """ set_test_flag_main """
  203. z = set_flag(x)
  204. z = z + y
  205. return z
  206. def test_set_flag():
  207. """ Test default parameter function call """
  208. log.debug("begin test_set_flag")
  209. ret = set_test_flag_main(2, 3)
  210. log.debug("finished test_set_flag, ret = %r", ret)
  211. @dataclass
  212. class Access:
  213. a: int
  214. b: int
  215. def max(self):
  216. if self.a > self.b:
  217. return self.a
  218. return self.b
  219. @ms_function
  220. def invoke_dataclass(x, y):
  221. """ invoke_dataclass """
  222. acs = Access(x, y)
  223. return acs.max()
  224. def test_access():
  225. """ test_access """
  226. invoke_dataclass(1, 2)
  227. @dataclass
  228. class Access2:
  229. a: int
  230. b: int
  231. def max(self):
  232. if self.a > self.b:
  233. return self.c
  234. return self.b
  235. @ms_function
  236. def invoke_dataclass2(x, y):
  237. """ invoke_dataclass """
  238. acs = Access2(x, y)
  239. return acs.max()
  240. def test_access_attr_error():
  241. """ test_access """
  242. with pytest.raises(AttributeError):
  243. invoke_dataclass2(2, 1)
  244. def myfunc(x):
  245. """ myfunc """
  246. return x * x
  247. @ms_function
  248. def ms_infer_for():
  249. """ ms_infer_for """
  250. a = 0.0
  251. for x in [1.1, 2.3, 3.3]:
  252. a = a + x
  253. return a
  254. def test_infer_for():
  255. """ test_infer_for """
  256. ms_infer_for()
  257. @ms_function
  258. def ms_infer_for_func(y):
  259. """ ms_infer_for_func """
  260. for x in [1.0, 2.0, 3.0]:
  261. y = myfunc(x) + y
  262. return y
  263. def test_ms_infer_for_func():
  264. """ test_ms_infer_for_func """
  265. ms_infer_for_func(1.0)
  266. @ms_function
  267. def add(x, y):
  268. """ add """
  269. return x + y
  270. def test_add():
  271. """ test_add """
  272. res = add(1, 2.0)
  273. return res
  274. @ms_function
  275. def add_list():
  276. """ add_list """
  277. a = [1, 2, 3]
  278. b = a[1] + a[2]
  279. return b
  280. def test_list():
  281. """ test_list """
  282. return add_list()
  283. @ms_function
  284. def compare_list_len():
  285. """ compare_list_len """
  286. a = [1, 2, 3]
  287. return ms_len(a)
  288. def test_list_len():
  289. """ test_list_len """
  290. compare_list_len()
  291. @ms_function
  292. def add_tuple():
  293. """ add_tuple """
  294. a = (1, 2, 3)
  295. b = a[1] + a[2]
  296. return b
  297. def test_tuple():
  298. """ test_tuple """
  299. return add_tuple()
  300. def invoke_func(x):
  301. """ invoke_func """
  302. return x * x
  303. @ms_function
  304. def tuple_of_node(x, y):
  305. """ tuple_of_node """
  306. a = invoke_func(x)
  307. b = invoke_func(y)
  308. c = (a, b)
  309. d = c[1] * x
  310. return d
  311. def test_tuple_node():
  312. """ test_tuple_node """
  313. res = tuple_of_node(1, 2)
  314. return res
  315. @ms_function
  316. def range_spec(x, y):
  317. """ range_spec """
  318. for _ in range(1, 10, 3):
  319. x = x + 1
  320. return x + y
  321. def test_range():
  322. """ test_range """
  323. res = range_spec(10, 10)
  324. return res
  325. def test_expr():
  326. """ test const expr """
  327. a = (1, 2)
  328. @constexpr
  329. def tuple_len(x):
  330. assert len(x) == 2
  331. tuple_len(a)
  332. def test_tuple_to_array():
  333. """ test range tuple to array """
  334. range_x = range(10)
  335. res = F.tuple_to_array(range_x)
  336. print(res)