You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_parse_method.py 9.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. @File : test_parse_method.py
  17. @Author:
  18. @Date : 2019-06-27
  19. @Desc : test parse the object's method
  20. """
  21. import logging
  22. import numpy as np
  23. import mindspore.nn as nn
  24. from dataclasses import dataclass
  25. from mindspore.common.api import ms_function
  26. from mindspore.common.tensor import Tensor
  27. from mindspore.ops.composite import core
  28. from mindspore._extends.parse.standard_method import ms_len
  29. from ..ut_filter import non_graph_engine
  30. import pytest
  31. from mindspore import context
  32. def setup_module(module):
  33. context.set_context(mode=context.PYNATIVE_MODE)
  34. log = logging.getLogger("test")
  35. log.setLevel(level=logging.ERROR)
@ms_function
def default_parameter_f(x, y=3):
    """Return ``x + y``, where ``y`` defaults to 3.

    Exercises parsing of a function that declares a default parameter value.
    """
    z = x + y
    return z
  41. # Test case: test parse fn that use default parameter
  42. def test_parse_defalut_parameter_case1():
  43. """ Test default parameter function call """
  44. log.debug("begin test_parse_defalut_parameter_case1")
  45. ret = default_parameter_f(2)
  46. log.debug("finished test_parse_defalut_parameter_case1, ret = %r", ret)
  47. def get_val_fn(x):
  48. """ get_val_fn """
  49. ret = x + 3
  50. return ret
# Test case: parse the boolean ``not`` operator.
@ms_function
def bool_exp(x, y):
    """Return ``not x > y``.

    Comparison binds tighter than ``not``, so this is ``not (x > y)``: True
    when ``x`` is not greater than ``y``. The expression is kept verbatim
    because parsing ``not`` is what this test exercises.
    """
    return not x > y
  56. def test_bool_exp():
  57. """ test_bool_exp """
  58. bool_exp(1, 2)
# Test case: use a variable-length positional parameter (*args) with @ms_function.
@ms_function
def var_parameter_f(x, *args):
    """Return ``x`` plus the first three extra positional arguments.

    ``args[0]``, ``args[1]`` and ``args[2]`` are indexed, so at least three
    extra arguments are needed; any beyond the third are ignored.
    """
    z = x + args[0] + args[1] + args[2]
    return z
  65. def test_var_parameter_case1():
  66. """ test_var_parameter_case1 """
  67. log.debug("start test_var_parameter_case1")
  68. var_parameter_f(1, 2, 3, 4, 5)
  69. log.debug("end test_var_parameter_case1")
class Net(nn.Cell):
    """Cell whose ``construct`` calls another method of the same object.

    Args:
        value1: value added to the input by ``get_test_value`` (the tests
            pass a Tensor).
    """
    def __init__(self, value1):
        super(Net, self).__init__()
        # relu, softmax, axis and TC are not read by construct() or
        # get_test_value(); they only add extra attributes to the cell.
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(0)
        self.axis = 0
        self.TC = ClassTest("test_class", 1.2)
        self.value = value1

    @ms_function
    def construct(self, x):
        """Return ``x + self.value``, computed through ``self.get_test_value``."""
        x = self.get_test_value(x)
        return x

    def get_test_value(self, x):
        """Return ``x + self.value``."""
        ret = x + self.value
        return ret
class ClassTest:
    """Plain (non-Cell) Python object whose methods are called from graph code.

    Args:
        name: stored as ``self.name``.
        value1: stored as ``self.value``.
    """
    def __init__(self, name, value1):
        self.name = name
        self.value = value1

    def get_name(self):
        """Return the stored name."""
        return self.name

    def get_value(self, inc):
        """Return ``self.value + inc``."""
        ret = self.value + inc
        return ret

    def __call__(self, *args, **kwargs):
        # Makes instances callable; calling one does nothing and returns None.
        pass
  98. # Test: call method on parse graph code
  99. @non_graph_engine
  100. def test_call_method_on_construct():
  101. """ test_call_method_on_construct """
  102. log.debug("begin test_call_method_on_construct")
  103. x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
  104. y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
  105. z = np.array([[3, 5, 7], [2, 3, 5]]).astype(np.int32)
  106. net = Net(y)
  107. output = net.construct(x)
  108. result = output.asnumpy()
  109. print(result)
  110. assert np.all(result == z)
  111. log.debug("finished test_call_method_on_construct")
# Test: call a method of another (non-Cell) object held as a cell attribute.
class Net1(nn.Cell):
    """Cell whose ``construct`` calls ``self.TC.get_value``.

    ``test_call_other_object_method`` expects ``construct`` to raise TypeError.

    Args:
        v1: value stored in the ClassTest helper (``self.TC.value``).
        v2: argument passed to ``self.TC.get_value`` (``self.value``).
    """
    def __init__(self, v1, v2):
        super(Net1, self).__init__()
        # relu, softmax and axis are not read by construct().
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(0)
        self.axis = 0
        self.TC = ClassTest("test_class", v1)
        self.value = v2

    @ms_function
    def construct(self, x):
        """Return ``x + self.TC.get_value(self.value)``, i.e. ``x + v1 + v2``."""
        x = x + self.TC.get_value(self.value)
        return x
  126. @non_graph_engine
  127. def test_call_other_object_method():
  128. """ test_call_other_object_method """
  129. log.debug("begin test_call_other_object_method")
  130. x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
  131. y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
  132. y1 = Tensor(np.array([[5, 4, 5], [1, 1, 2]]).astype(np.int32))
  133. z = np.array([[8, 9, 12], [3, 4, 7]]).astype(np.int32)
  134. net = Net1(y, y1)
  135. with pytest.raises(TypeError):
  136. output = net.construct(x)
  137. result = output.asnumpy()
  138. print(result)
  139. assert np.all(result == z)
  140. log.debug("finished test_call_other_object_method")
  141. # Test: call global object method(not self) on parse graph code
  142. value = Tensor(np.array([[3, 4, 5], [1, 1, 2]]).astype(np.int32))
  143. TC = ClassTest("test_class", value)
class Net2(nn.Cell):
    """Cell whose graph code reads the module-level ``TC`` object, not ``self``.

    Args:
        value1: stored as ``self.value``.
    """
    def __init__(self, value1):
        super(Net2, self).__init__()
        self.value = value1

    @ms_function
    def construct(self, x):
        """Return ``x + TC.get_value(self.value)``, calling a method of the global ``TC``."""
        x = x + TC.get_value(self.value)
        return x

    @ms_function
    def construct1(self, x):
        """Return ``x + TC.value + self.value``, reading an attribute of the global ``TC``."""
        x = x + TC.value
        x = x + self.value
        return x
  158. @non_graph_engine
  159. def test_call_no_self_other_object_method():
  160. """ test_call_no_self_other_object_method """
  161. log.debug("begin test_call_other_object_method")
  162. x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
  163. y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
  164. z = np.array([[6, 9, 12], [3, 4, 7]]).astype(np.int32)
  165. net = Net2(y)
  166. with pytest.raises(TypeError):
  167. output = net.construct(x)
  168. result = output.asnumpy()
  169. print(result)
  170. assert np.all(result == z)
  171. log.debug("finished test_call_other_object_method")
  172. def test_call_no_self_other_object_attr_value():
  173. """ test_call_no_self_other_object_attr_value """
  174. # do not support tensor as init input.
  175. return
  176. # Test case: use the * to unlock the varargs for @mindspore
  177. def vararg1(x, y):
  178. """ vararg1 """
  179. z = x + y
  180. return z
def varargs_main(fn):
    """Wrap ``fn`` in an @ms_function that forwards all positional arguments.

    Args:
        fn: callable invoked as ``fn(*args)`` inside the decorated closure.

    Returns:
        The @ms_function-decorated closure ``t1``.
    """
    @ms_function
    def t1(*args):
        # Star-unpacking of the varargs tuple is the construct under test.
        return fn(*args)
    return t1
  187. def test_var_parameter_case3():
  188. """ test_var_parameter_case3 """
  189. log.debug("start test_var_parameter_case3")
  190. ret = varargs_main(vararg1)(1, 2)
  191. log.debug("ret = %r", ret)
  192. log.debug("end test_var_parameter_case3")
# Test case: mark a function with a flag through the ``core`` decorator.
@core(tg=True)
def set_flag(x):
    """Return ``x + 1``; the function is decorated with ``core(tg=True)``.

    NOTE(review): what the ``tg`` flag does is defined by
    mindspore.ops.composite.core and is not visible in this file.
    """
    return x + 1
@ms_function
def set_test_flag_main(x, y):
    """Return ``set_flag(x) + y``, calling the flagged function from graph code."""
    z = set_flag(x)
    z = z + y
    return z
  204. def test_set_flag():
  205. """ Test default parameter function call """
  206. log.debug("begin test_set_flag")
  207. ret = set_test_flag_main(2, 3)
  208. log.debug("finished test_set_flag, ret = %r", ret)
@dataclass
class Access:
    """Dataclass with two fields; built and queried inside invoke_dataclass."""
    a: int
    b: int

    def max(self):
        """Return the larger of ``a`` and ``b`` (``b`` when they are equal)."""
        if self.a > self.b:
            return self.a
        return self.b
@ms_function
def invoke_dataclass(x, y):
    """Construct ``Access(x, y)`` inside graph code and return its ``max()``."""
    acs = Access(x, y)
    return acs.max()
  222. def test_access():
  223. """ test_access """
  224. invoke_dataclass(1, 2)
  225. def myfunc(x):
  226. """ myfunc """
  227. return x * x
@ms_function
def ms_infer_for():
    """Return the sum of a constant float list, accumulated in a ``for`` loop."""
    a = 0.0
    for x in [1.1, 2.3, 3.3]:
        a = a + x
    return a
  235. def test_infer_for():
  236. """ test_infer_for """
  237. ms_infer_for()
@ms_function
def ms_infer_for_func(y):
    """Return ``y`` plus ``myfunc(x)`` for each ``x`` in ``[1.0, 2.0, 3.0]``.

    ``myfunc`` squares its argument, so for ``y == 1.0`` this is 15.0.
    """
    for x in [1.0, 2.0, 3.0]:
        y = myfunc(x) + y
    return y
  244. def test_ms_infer_for_func():
  245. """ test_ms_infer_for_func """
  246. ms_infer_for_func(1.0)
@ms_function
def add(x, y):
    """Return ``x + y`` computed in graph code."""
    return x + y
  251. def test_add():
  252. """ test_add """
  253. res = add(1, 2.0)
  254. return res
@ms_function
def add_list():
    """Index a constant list inside graph code and return ``a[1] + a[2]`` (5)."""
    a = [1, 2, 3]
    b = a[1] + a[2]
    return b
  261. def test_list():
  262. """ test_list """
  263. return add_list()
@ms_function
def compare_list_len():
    """Return ``ms_len`` of a constant three-element list.

    Despite the name, no comparison is performed; only the length is returned.
    """
    a = [1, 2, 3]
    return ms_len(a)
  269. def test_list_len():
  270. """ test_list_len """
  271. compare_list_len()
@ms_function
def add_tuple():
    """Index a constant tuple inside graph code and return ``a[1] + a[2]`` (5)."""
    a = (1, 2, 3)
    b = a[1] + a[2]
    return b
  278. def test_tuple():
  279. """ test_tuple """
  280. return add_tuple()
  281. def invoke_func(x):
  282. """ invoke_func """
  283. return x * x
@ms_function
def tuple_of_node(x, y):
    """Build a tuple of computed values in graph code and index into it.

    Returns ``invoke_func(y) * x``, i.e. ``y * y * x``.
    """
    a = invoke_func(x)  # only fills the tuple; c[0] is never read
    b = invoke_func(y)
    c = (a, b)
    d = c[1] * x
    return d
  292. def test_tuple_node():
  293. """ test_tuple_node """
  294. res = tuple_of_node(1, 2)
  295. return res
@ms_function
def range_spec(x, y):
    """Increment ``x`` once per element of ``range(1, 10, 3)``, then return ``x + y``.

    ``range(1, 10, 3)`` yields 1, 4, 7 (three iterations), so the result is
    ``x + 3 + y``.
    """
    for _ in range(1, 10, 3):
        x = x + 1
    return x + y
  302. def test_range():
  303. """ test_range """
  304. res = range_spec(10, 10)
  305. return res