You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_graph_fallback.py 7.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test graph fallback """
  16. import pytest
  17. import numpy as np
  18. import mindspore.nn as nn
  19. from mindspore import Tensor, ms_function, context
  20. from mindspore.ops import operations as P
  21. from mindspore.ops import functional as F
  22. import mindspore.common.dtype as mstype
  23. import mindspore.common._monad as monad
# Compile every ms_function and Cell in this module in graph mode; the tests
# below target the JIT Fallback behaviour of graph compilation.
context.set_context(mode=context.GRAPH_MODE)
# `add_func` is defined in the current file (module scope); `do_increment`
# binds it with F.partial inside an ms_function, so it is traced into a graph.
def add_func(x, y):
    # Return x + y.
    return x + y
@ms_function
def do_increment(i):
    # Partially apply the module-level `add_func` with first argument 1 and
    # call it on `i`, i.e. return add_func(1, i) == 1 + i.
    add_1 = F.partial(add_func, 1)
    return add_1(i)
  32. def test_increment():
  33. a = do_increment(9)
  34. assert a == 10
@ms_function
def use_monad(x, y):
    # Multiply x by y, then tie the product to the monad object `monad.U`
    # through F.depend and return it -- presumably to check that a monad can
    # be referenced from inside a compiled function.
    res = P.Mul()(x, y)
    res = F.depend(res, monad.U)
    return res
  40. def test_use_monad():
  41. x = Tensor(1.0, mstype.float32)
  42. y = Tensor(1.0, mstype.float32)
  43. print(use_monad(x, y))
@ms_function
def use_tensor_with_mstype():
    # Construct a scalar Tensor with an explicit mstype dtype inside the
    # compiled function and return it.
    me_x = Tensor(1, mstype.int32)
    return me_x
  48. @pytest.mark.skip(reason='Not support graph fallback feature yet')
  49. def test_tensor_with_mstype():
  50. """
  51. Feature: JIT Fallback
  52. Description: Test tensor with mstype in graph mode.
  53. Expectation: No exception.
  54. """
  55. print(use_tensor_with_mstype())
class Net(nn.Cell):
    """Cell whose `construct` calls the builtin `len` on a Tensor attribute."""
    def __init__(self):
        super(Net, self).__init__()
        # 1-D Tensor with three elements.
        self.x = Tensor([2, 3, 4])

    def construct(self):
        # len() of the Tensor attribute is used as a range() bound and is
        # returned as the Cell's output.
        x_len = len(self.x)
        for i in range(x_len):
            print(i)
        return x_len
  65. def test_builtins_len():
  66. net = Net()
  67. net()
@ms_function
def np_fallback_func():
    # Build a float32 NumPy array from a tuple, wrap it in a Tensor and add the
    # Tensor to itself: [2, 3, 4, 5] -> [4., 6., 8., 10.].
    array_x = tuple([2, 3, 4, 5])
    np_x = np.array(array_x).astype(np.float32)
    me_x = Tensor(np_x)
    me_x = me_x + me_x
    return me_x
  75. def test_np_fallback_func():
  76. print(np_fallback_func())
# Test an interpret node (divmod on constants) whose result flows into `return`.
@ms_function
def div_mod_func1():
    # divmod(8, 3) == (2, 2); the tuple is wrapped in a Tensor and returned.
    x = 8
    y = 3
    a = divmod(x, y)
    return Tensor(a)
  84. def test_div_mod_func1():
  85. print(div_mod_func1()) # (2, 2)
# Test an interpret node that takes the function's parameters as input.
@ms_function
def div_mod_func2(x, y):
    # divmod(x, y) on the arguments, wrapped in a Tensor. Scalar ints are
    # exercised by test_div_mod_func2_scalar; Tensor arguments are covered by
    # the skipped test_div_mod_func2_tensor, which expects a RuntimeError.
    a = divmod(x, y)
    return Tensor(a)
  91. def test_div_mod_func2_scalar():
  92. """
  93. Feature: JIT Fallback
  94. Description: Test divmod in graph.
  95. Expectation: No exception.
  96. """
  97. print(div_mod_func2(8, 3)) # (2, 2)
  98. @pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
  99. def test_div_mod_func2_tensor():
  100. """
  101. Feature: JIT Fallback
  102. Description: Test divmod with Tensor input in graph. We'll support it in Tensor Input Fallback solution.
  103. Expectation: Not supported exception.
  104. """
  105. with pytest.raises(RuntimeError) as err:
  106. print(div_mod_func2(Tensor(8), Tensor(3)))
  107. assert "Not support Tensor or variable type as input during running JIT Fallback, but got" in str(err.value)
# NOTE(review): recorded error "NameError: name 'Tensor' is not defined."; the
# test calling this function is not skipped, so this presumably describes an
# earlier failure -- confirm before removing.
@ms_function
def select_func(cond, x, y):
    # Choose by the Python type of `cond`: tuple/list -> y; Tensor ->
    # element-wise F.select(cond, x, y); anything else -> x. The isinstance
    # check against `Tensor` runs inside the compiled function.
    if isinstance(cond, (tuple, list)):
        output = y
    elif isinstance(cond, Tensor):
        output = F.select(cond, x, y)
    else:
        output = x
    return output
  118. def test_select_func():
  119. cond = Tensor([True, False])
  120. x = Tensor([2, 3], mstype.float32)
  121. y = Tensor([1, 2], mstype.float32)
  122. print(select_func(cond, x, y))
# Not interpret 'Tensor'. Same as select_func but with two independent `if`
# statements instead of an if/elif chain.
@ms_function
def select_func2(cond, x, y):
    # NOTE(review): because the second test is `if`, not `elif`, a tuple/list
    # `cond` first sets output = y and is then overwritten by the `else`
    # branch (output = x). Presumably intentional, to exercise this control-flow
    # shape; test_select_func2 only exercises the Tensor path -- confirm.
    if isinstance(cond, (tuple, list)):
        output = y
    if isinstance(cond, Tensor):
        output = F.select(cond, x, y)
    else:
        output = x
    return output
  133. def test_select_func2():
  134. cond = Tensor([True, False])
  135. x = Tensor([2, 3], mstype.float32)
  136. y = Tensor([1, 2], mstype.float32)
  137. print(select_func2(cond, x, y))
# NOTE(review): recorded error "NameError: name 'Tensor' is not defined."; the
# test calling this function is not skipped, so this presumably describes an
# earlier failure -- confirm before removing.
@ms_function
def slice_func(a, b):
    # Slice assignment on a Tensor argument: overwrite indices 1 and 2 of the
    # first axis (full extent on the other axes) with `b`, then return `a`.
    a[1:3, ::] = b
    return a
  143. def test_slice_func():
  144. a = Tensor(np.arange(60).reshape(3, 4, 5), dtype=mstype.float32)
  145. b = Tensor([1], dtype=mstype.float32)
  146. print(slice_func(a, b))
@ms_function
def np_fallback_func_tensor_index(x):
    # Same computation as np_fallback_func ([2, 3, 4, 5] as float32, doubled),
    # then index the resulting Tensor with the argument `x`.
    array_x = tuple([2, 3, 4, 5])
    np_x = np.array(array_x).astype(np.float32)
    me_x = Tensor(np_x)
    me_x = me_x + me_x
    return me_x[x]
  154. # NameError: name 'array_x' is not defined.
  155. @pytest.mark.skip(reason='Not support graph fallback feature yet')
  156. def test_np_fallback_func_tensor_index():
  157. """
  158. Feature: Fallback feature: support Tensor index.
  159. Description: Fallback feature: support Tensor index.
  160. Expectation: Fallback feature: support Tensor index.
  161. """
  162. x = Tensor(1, mstype.int32)
  163. output = np_fallback_func_tensor_index(x)
  164. output_expect = Tensor(6, mstype.float32)
  165. assert output == output_expect
class ControlNet(nn.Cell):
    """Cell that defines Tensors inside `construct` and returns a + b or a - b."""
    def __init__(self):
        super(ControlNet, self).__init__()

    # Helper called from construct when a + b > x.
    def inner_function_1(self, a, b):
        return a + b

    # Helper called from construct otherwise.
    def inner_function_2(self, a, b):
        return a - b

    def construct(self, x):
        # a and b are Tensors created inside construct (values 4 and 5), so the
        # branch compares a + b == 9 against the input `x`.
        a = Tensor(np.array(4), mstype.int32)
        b = Tensor(np.array(5), mstype.int32)
        if a + b > x:
            return self.inner_function_1(a, b)
        return self.inner_function_2(a, b)
  179. # NameError: name 'mstype' is not defined.
  180. @pytest.mark.skip(reason='Not support graph fallback feature yet')
  181. def test_fallback_control_sink_tensor():
  182. """
  183. Feature: Fallback feature: support define Tensor in Class construct.
  184. Description: Fallback feature: support define Tensor in Class construct.
  185. Expectation: Fallback feature: support define Tensor in Class construct.
  186. """
  187. x = Tensor(np.array(1), mstype.int32)
  188. net = ControlNet()
  189. output = net(x)
  190. output_expect = Tensor(9, mstype.int32)
  191. assert output == output_expect
# Error noted for this case (the test is skipped):
# NameError: name 'mytype' is not defined
# NOTE(review): 'mytype' may be a typo for 'mstype' in this note -- confirm.
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_tensor_list():
    """
    Feature: Fallback feature
    Description: Build a list of Tensors inside an ms_function, iterate over it,
        and append one more element after the loop.
    Expectation: The returned list has three elements.
    """
    @ms_function
    def np_tensor_list():
        # Two-element Tensor list; after printing each element, append
        # (last element + c) once, giving three elements.
        a = Tensor(np.array(4), mstype.int32)
        b = Tensor(np.array(5), mstype.int32)
        c = Tensor(np.array(6), mstype.int32)
        tensor_list = [a, b]
        for tensor in tensor_list:
            print(tensor)
        tensor_list.append(tensor_list[-1] + c)
        return tensor_list
    tensor_list = np_tensor_list()
    print("tensor_list:", tensor_list)
    assert len(tensor_list) == 3
# Error noted for this case (the test is skipped):
# EvalCNode: This may be not defined, or it can't be a operator.
@pytest.mark.skip(reason='Not support graph fallback feature yet')
def test_np_tensor_add():
    """
    Feature: Fallback feature
    Description: Add a Tensor built from a NumPy scalar to the last element of a
        Tensor list inside an ms_function, and append the sum to the list.
    Expectation: The last list element is 5 + 6 == 11.
    """
    @ms_function
    def np_tensor_add():
        # Tensors created without an explicit dtype; after printing each list
        # element, append (last element + Tensor(6)) once.
        a = Tensor(np.array(4))
        b = Tensor(np.array(5))
        tensor_list = [a, b]
        for tensor in tensor_list:
            print(tensor)
        x = 6
        np_x = np.array(x)
        c = Tensor(np_x)
        d = tensor_list[-1] + c
        tensor_list.append(d)
        return tensor_list
    tensor_list = np_tensor_add()
    print("tensor_list:", tensor_list)
    assert tensor_list[-1] == 11