You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_parse_method.py 9.0 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
@File : test_parse_method.py
@Author:
@Date : 2019-06-27
@Desc : test parse the object's method
"""
  21. import logging
  22. import numpy as np
  23. import pytest
  24. from dataclasses import dataclass
  25. import mindspore.nn as nn
  26. from mindspore import context
  27. from mindspore._extends.parse.standard_method import ms_len
  28. from mindspore.common.api import ms_function
  29. from mindspore.common.tensor import Tensor
  30. from mindspore.ops.composite import core
  31. from ..ut_filter import non_graph_engine
  32. def setup_module(module):
  33. context.set_context(mode=context.PYNATIVE_MODE)
  34. log = logging.getLogger("test")
  35. log.setLevel(level=logging.ERROR)
  36. @ms_function
  37. def default_parameter_f(x, y=3):
  38. """ default_parameter_f """
  39. z = x + y
  40. return z
  41. # Test case: test parse fn that use default parameter
  42. def test_parse_defalut_parameter_case1():
  43. """ Test default parameter function call """
  44. log.debug("begin test_parse_defalut_parameter_case1")
  45. ret = default_parameter_f(2)
  46. log.debug("finished test_parse_defalut_parameter_case1, ret = %r", ret)
  47. def get_val_fn(x):
  48. """ get_val_fn """
  49. ret = x + 3
  50. return ret
  51. # Test case: test bool not
  52. @ms_function
  53. def bool_exp(x, y):
  54. """ bool_exp """
  55. return not x > y
  56. def test_bool_exp():
  57. """ test_bool_exp """
  58. bool_exp(1, 2)
  59. # Test case: use the variable parameter for @mindspore
  60. @ms_function
  61. def var_parameter_f(x, *args):
  62. """ var_parameter_f """
  63. z = x + args[0] + args[1] + args[2]
  64. return z
  65. def test_var_parameter_case1():
  66. """ test_var_parameter_case1 """
  67. log.debug("start test_var_parameter_case1")
  68. var_parameter_f(1, 2, 3, 4, 5)
  69. log.debug("end test_var_parameter_case1")
  70. class Net(nn.Cell):
  71. """ Net definition """
  72. def __init__(self, value1):
  73. super(Net, self).__init__()
  74. self.relu = nn.ReLU()
  75. self.softmax = nn.Softmax(0)
  76. self.axis = 0
  77. self.TC = ClassTest("test_class", 1.2)
  78. self.value = value1
  79. @ms_function
  80. def construct(self, x):
  81. x = self.get_test_value(x)
  82. return x
  83. def get_test_value(self, x):
  84. ret = x + self.value
  85. return ret
  86. class ClassTest:
  87. """ ClassTest definition """
  88. def __init__(self, name, value1):
  89. self.name = name
  90. self.value = value1
  91. def get_name(self):
  92. return self.name
  93. def get_value(self, inc):
  94. ret = self.value + inc
  95. return ret
  96. def __call__(self, *args, **kwargs):
  97. pass
  98. # Test: call method on parse graph code
  99. @non_graph_engine
  100. def test_call_method_on_construct():
  101. """ test_call_method_on_construct """
  102. log.debug("begin test_call_method_on_construct")
  103. x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
  104. y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
  105. z = np.array([[3, 5, 7], [2, 3, 5]]).astype(np.int32)
  106. net = Net(y)
  107. output = net.construct(x)
  108. result = output.asnumpy()
  109. print(result)
  110. assert np.all(result == z)
  111. log.debug("finished test_call_method_on_construct")
  112. # Test: call method on parse graph code
  113. class Net1(nn.Cell):
  114. """ Net1 definition """
  115. def __init__(self, v1, v2):
  116. super(Net1, self).__init__()
  117. self.relu = nn.ReLU()
  118. self.softmax = nn.Softmax(0)
  119. self.axis = 0
  120. self.TC = ClassTest("test_class", v1)
  121. self.value = v2
  122. @ms_function
  123. def construct(self, x):
  124. x = x + self.TC.get_value(self.value)
  125. return x
  126. @non_graph_engine
  127. def test_call_other_object_method():
  128. """ test_call_other_object_method """
  129. log.debug("begin test_call_other_object_method")
  130. x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
  131. y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
  132. y1 = Tensor(np.array([[5, 4, 5], [1, 1, 2]]).astype(np.int32))
  133. z = np.array([[8, 9, 12], [3, 4, 7]]).astype(np.int32)
  134. net = Net1(y, y1)
  135. with pytest.raises(TypeError):
  136. output = net.construct(x)
  137. result = output.asnumpy()
  138. print(result)
  139. assert np.all(result == z)
  140. log.debug("finished test_call_other_object_method")
  141. # Test: call global object method(not self) on parse graph code
  142. value = Tensor(np.array([[3, 4, 5], [1, 1, 2]]).astype(np.int32))
  143. TC = ClassTest("test_class", value)
  144. class Net2(nn.Cell):
  145. """ Net2 definition """
  146. def __init__(self, value1):
  147. super(Net2, self).__init__()
  148. self.value = value1
  149. @ms_function
  150. def construct(self, x):
  151. x = x + TC.get_value(self.value)
  152. return x
  153. @ms_function
  154. def construct1(self, x):
  155. x = x + TC.value
  156. x = x + self.value
  157. return x
  158. @non_graph_engine
  159. def test_call_no_self_other_object_method():
  160. """ test_call_no_self_other_object_method """
  161. log.debug("begin test_call_other_object_method")
  162. x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
  163. y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
  164. z = np.array([[6, 9, 12], [3, 4, 7]]).astype(np.int32)
  165. net = Net2(y)
  166. with pytest.raises(TypeError):
  167. output = net.construct(x)
  168. result = output.asnumpy()
  169. print(result)
  170. assert np.all(result == z)
  171. log.debug("finished test_call_other_object_method")
  172. def test_call_no_self_other_object_attr_value():
  173. """ test_call_no_self_other_object_attr_value """
  174. # do not support tensor as init input.
  175. return
  176. # Test case: use the * to unlock the varargs for @mindspore
  177. def vararg1(x, y):
  178. """ vararg1 """
  179. z = x + y
  180. return z
  181. def varargs_main(fn):
  182. """ varargs_main """
  183. @ms_function
  184. def t1(*args):
  185. return fn(*args)
  186. return t1
  187. def test_var_parameter_case3():
  188. """ test_var_parameter_case3 """
  189. log.debug("start test_var_parameter_case3")
  190. ret = varargs_main(vararg1)(1, 2)
  191. log.debug("ret = %r", ret)
  192. log.debug("end test_var_parameter_case3")
  193. # Test case: test the flag set
  194. @core(tg=True)
  195. def set_flag(x):
  196. """ set_flag """
  197. return x + 1
  198. @ms_function
  199. def set_test_flag_main(x, y):
  200. """ set_test_flag_main """
  201. z = set_flag(x)
  202. z = z + y
  203. return z
  204. def test_set_flag():
  205. """ Test default parameter function call """
  206. log.debug("begin test_set_flag")
  207. ret = set_test_flag_main(2, 3)
  208. log.debug("finished test_set_flag, ret = %r", ret)
  209. @dataclass
  210. class Access:
  211. a: int
  212. b: int
  213. def max(self):
  214. if self.a > self.b:
  215. return self.a
  216. return self.b
  217. @ms_function
  218. def invoke_dataclass(x, y):
  219. """ invoke_dataclass """
  220. acs = Access(x, y)
  221. return acs.max()
  222. def test_access():
  223. """ test_access """
  224. invoke_dataclass(1, 2)
  225. def myfunc(x):
  226. """ myfunc """
  227. return x * x
  228. @ms_function
  229. def ms_infer_for():
  230. """ ms_infer_for """
  231. a = 0.0
  232. for x in [1.1, 2.3, 3.3]:
  233. a = a + x
  234. return a
  235. def test_infer_for():
  236. """ test_infer_for """
  237. ms_infer_for()
  238. @ms_function
  239. def ms_infer_for_func(y):
  240. """ ms_infer_for_func """
  241. for x in [1.0, 2.0, 3.0]:
  242. y = myfunc(x) + y
  243. return y
  244. def test_ms_infer_for_func():
  245. """ test_ms_infer_for_func """
  246. ms_infer_for_func(1.0)
  247. @ms_function
  248. def add(x, y):
  249. """ add """
  250. return x + y
  251. def test_add():
  252. """ test_add """
  253. res = add(1, 2.0)
  254. return res
  255. @ms_function
  256. def add_list():
  257. """ add_list """
  258. a = [1, 2, 3]
  259. b = a[1] + a[2]
  260. return b
  261. def test_list():
  262. """ test_list """
  263. return add_list()
  264. @ms_function
  265. def compare_list_len():
  266. """ compare_list_len """
  267. a = [1, 2, 3]
  268. return ms_len(a)
  269. def test_list_len():
  270. """ test_list_len """
  271. compare_list_len()
  272. @ms_function
  273. def add_tuple():
  274. """ add_tuple """
  275. a = (1, 2, 3)
  276. b = a[1] + a[2]
  277. return b
  278. def test_tuple():
  279. """ test_tuple """
  280. return add_tuple()
  281. def invoke_func(x):
  282. """ invoke_func """
  283. return x * x
  284. @ms_function
  285. def tuple_of_node(x, y):
  286. """ tuple_of_node """
  287. a = invoke_func(x)
  288. b = invoke_func(y)
  289. c = (a, b)
  290. d = c[1] * x
  291. return d
  292. def test_tuple_node():
  293. """ test_tuple_node """
  294. res = tuple_of_node(1, 2)
  295. return res
  296. @ms_function
  297. def range_spec(x, y):
  298. """ range_spec """
  299. for _ in range(1, 10, 3):
  300. x = x + 1
  301. return x + y
  302. def test_range():
  303. """ test_range """
  304. res = range_spec(10, 10)
  305. return res