You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_graph_fallback.py 7.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test graph-mode JIT fallback: Tensors, builtin list methods and numpy used inside compiled code."""
import pytest
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, ms_function, context
import mindspore.common.dtype as mstype

# Executed at import time, so every test below runs in graph mode unless a
# test changes the context itself.
context.set_context(mode=context.GRAPH_MODE)
class ControlNet(nn.Cell):
    """Cell whose ``construct`` creates Tensors locally and branches on them.

    ``construct`` (and the helpers it calls) is compiled by the graph
    compiler, so its exact form is the construct under test: Tensors built
    inside ``construct``, an ``if`` with an early ``return``, and a
    fall-through ``return``.
    """

    def inner_function_1(self, a, b):
        # Branch taken when a + b > x: returns the sum.
        return a + b

    def inner_function_2(self, a, b):
        # Fall-through branch: returns the difference.
        return a - b

    def construct(self, x):
        # Tensors defined inside construct (not passed in) -- the fallback
        # feature exercised by test_fallback_control_sink_tensor.
        a = Tensor(np.array(4), mstype.int32)
        b = Tensor(np.array(5), mstype.int32)
        # a + b == 9, so the first branch is taken for any x below 9.
        if a + b > x:
            return self.inner_function_1(a, b)
        return self.inner_function_2(a, b)
  33. @pytest.mark.level0
  34. @pytest.mark.platform_x86_gpu_training
  35. @pytest.mark.platform_arm_ascend_training
  36. @pytest.mark.platform_x86_ascend_training
  37. @pytest.mark.env_onecard
  38. def test_fallback_control_sink_tensor():
  39. """
  40. Feature: Fallback feature: support define Tensor in Class construct.
  41. Description: Fallback feature: support define Tensor in Class construct.
  42. Expectation: Fallback feature: support define Tensor in Class construct.
  43. """
  44. x = Tensor(np.array(1), mstype.int32)
  45. net = ControlNet()
  46. output = net(x)
  47. output_expect = Tensor(9, mstype.int32)
  48. assert output == output_expect
  49. @pytest.mark.level0
  50. @pytest.mark.platform_x86_gpu_training
  51. @pytest.mark.platform_arm_ascend_training
  52. @pytest.mark.platform_x86_ascend_training
  53. @pytest.mark.env_onecard
  54. def test_np_tensor_list():
  55. """
  56. Feature: Fallback feature
  57. Description: support Basic method of Tensor list.
  58. Expectation: No exception.
  59. """
  60. @ms_function
  61. def np_tensor_list():
  62. a = Tensor(np.array(4), mstype.int32)
  63. b = Tensor(np.array(5), mstype.int32)
  64. c = Tensor(np.array(6), mstype.int32)
  65. tensor_list = [a, b]
  66. for tensor in tensor_list:
  67. print(tensor)
  68. tensor_list.append(tensor_list[-1] + c)
  69. return tensor_list
  70. tensor_list = np_tensor_list()
  71. print("tensor_list:", tensor_list)
  72. assert len(tensor_list) == 3
  73. @pytest.mark.level0
  74. @pytest.mark.platform_x86_gpu_training
  75. @pytest.mark.platform_arm_ascend_training
  76. @pytest.mark.platform_x86_ascend_training
  77. @pytest.mark.env_onecard
  78. def test_list_count():
  79. """
  80. Feature: Fallback feature
  81. Description: support attr/method of builtin type.
  82. Expectation: No exception.
  83. """
  84. @ms_function
  85. def list_count():
  86. x = list([1, 2, 3])
  87. res = x.count(1)
  88. return res
  89. assert list_count() == 1
  90. @pytest.mark.level0
  91. @pytest.mark.platform_x86_gpu_training
  92. @pytest.mark.platform_arm_ascend_training
  93. @pytest.mark.platform_x86_ascend_training
  94. @pytest.mark.env_onecard
  95. def test_list_append():
  96. """
  97. Feature: Fallback feature
  98. Description: support attr/method of builtin type.
  99. Expectation: No exception.
  100. """
  101. @ms_function
  102. def list_append():
  103. x = list([1, 2, 3])
  104. x.append(4)
  105. return Tensor(x)
  106. assert np.all(list_append().asnumpy() == np.array([1, 2, 3, 4]))
  107. @pytest.mark.level0
  108. @pytest.mark.platform_x86_gpu_training
  109. @pytest.mark.platform_arm_ascend_training
  110. @pytest.mark.platform_x86_ascend_training
  111. @pytest.mark.env_onecard
  112. def test_list_insert_1():
  113. """
  114. Feature: Fallback feature
  115. Description: support attr/method of builtin type.
  116. Expectation: No exception.
  117. """
  118. @ms_function
  119. def list_insert():
  120. x = list([1, 3, 4])
  121. x.insert(0, 2)
  122. return Tensor(x)
  123. assert np.all(list_insert().asnumpy() == np.array([2, 1, 3, 4]))
  124. @pytest.mark.level0
  125. @pytest.mark.platform_x86_gpu_training
  126. @pytest.mark.platform_arm_ascend_training
  127. @pytest.mark.platform_x86_ascend_training
  128. @pytest.mark.env_onecard
  129. def test_list_insert_2():
  130. """
  131. Feature: Fallback feature
  132. Description: support attr/method of builtin type.
  133. Expectation: No exception.
  134. """
  135. @ms_function
  136. def list_insert():
  137. x = list([1, 3, 4])
  138. x.insert(5, 2)
  139. return Tensor(x)
  140. assert np.all(list_insert().asnumpy() == np.array([1, 3, 4, 2]))
  141. @pytest.mark.platform_x86_gpu_training
  142. @pytest.mark.platform_arm_ascend_training
  143. @pytest.mark.platform_x86_ascend_training
  144. @pytest.mark.env_onecard
  145. def test_list_insert_3():
  146. """
  147. Feature: Fallback feature
  148. Description: support attr/method of builtin type.
  149. Expectation: No exception.
  150. """
  151. @ms_function
  152. def list_insert():
  153. x = list([1, 3, 4])
  154. x.insert(-1, 2)
  155. return Tensor(x)
  156. assert np.all(list_insert().asnumpy() == np.array([1, 3, 2, 4]))
  157. @pytest.mark.platform_x86_gpu_training
  158. @pytest.mark.platform_arm_ascend_training
  159. @pytest.mark.platform_x86_ascend_training
  160. @pytest.mark.env_onecard
  161. def test_list_insert_4():
  162. """
  163. Feature: Fallback feature
  164. Description: support attr/method of builtin type.
  165. Expectation: No exception.
  166. """
  167. @ms_function
  168. def list_insert():
  169. x = list([1, 3, 4])
  170. x.insert(-5, 2)
  171. return Tensor(x)
  172. assert np.all(list_insert().asnumpy() == np.array([2, 1, 3, 4]))
# Graph-compiled helper: builds [2, 3, 4, 5] as float32 via tuple -> numpy ->
# Tensor, doubles it to [4, 6, 8, 10], and indexes the result with ``x``.
# NOTE(review): x is presumably a scalar int32 Tensor index (as passed by
# test_np_fallback_func_tensor_index) -- confirm for any other caller.
@ms_function
def np_fallback_func_tensor_index(x):
    array_x = tuple([2, 3, 4, 5])
    np_x = np.array(array_x).astype(np.float32)
    me_x = Tensor(np_x)
    me_x = me_x + me_x
    return me_x[x]
  180. @pytest.mark.level0
  181. @pytest.mark.platform_x86_gpu_training
  182. @pytest.mark.platform_arm_ascend_training
  183. @pytest.mark.platform_x86_ascend_training
  184. @pytest.mark.env_onecard
  185. def test_np_fallback_func_tensor_index():
  186. """
  187. Feature: Fallback feature: support Tensor index.
  188. Description: Fallback feature: support Tensor index.
  189. Expectation: Fallback feature: support Tensor index.
  190. """
  191. x = Tensor(1, mstype.int32)
  192. output = np_fallback_func_tensor_index(x)
  193. output_expect = Tensor(6, mstype.float32)
  194. assert output == output_expect
  195. @pytest.mark.level0
  196. @pytest.mark.platform_x86_gpu_training
  197. @pytest.mark.platform_arm_ascend_training
  198. @pytest.mark.platform_x86_ascend_training
  199. @pytest.mark.env_onecard
  200. def test_np_calculate():
  201. """
  202. Feature: Fallback feature.
  203. Description: Support numpy calculation.
  204. Expectation: No exception.
  205. """
  206. @ms_function
  207. def np_calculate():
  208. x = np.array([3, 1, 2, 4, 5])
  209. y = x % 2
  210. z = Tensor(y)
  211. return z
  212. assert np.all(np_calculate().asnumpy() == np.array([1, 1, 0, 0, 1]))
  213. @pytest.mark.level0
  214. @pytest.mark.platform_x86_gpu_training
  215. @pytest.mark.platform_arm_ascend_training
  216. @pytest.mark.platform_x86_ascend_training
  217. @pytest.mark.env_onecard
  218. def test_fallback_tensor_array_astype():
  219. """
  220. Feature: JIT Fallback
  221. Description: Test Tensor(array) with astype() in graph mode.
  222. Expectation: No exception.
  223. """
  224. @ms_function
  225. def foo():
  226. me_x = Tensor([1.1, -2.1]).astype("float32")
  227. return me_x
  228. print(foo())