You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_graph_fallback.py 7.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test graph fallback """
  16. import functools
  17. import pytest
  18. import numpy as np
  19. import mindspore.nn as nn
  20. from mindspore import Tensor, ms_function, context
  21. from mindspore.ops import operations as P
  22. from mindspore.ops import functional as F
  23. from mindspore.nn.probability import distribution
  24. import mindspore.common.dtype as mstype
  25. import mindspore.common._monad as monad
  26. import mindspore.scipy.linalg as alg
  27. context.set_context(mode=context.GRAPH_MODE)
  28. # `add_func` is defined in current file.
  29. def add_func(x, y):
  30. return x + y
  31. @ms_function
  32. def do_increment(i):
  33. add_1 = F.partial(add_func, 1)
  34. return add_1(i)
  35. def test_increment():
  36. a = do_increment(9)
  37. assert a == 10
  38. @ms_function
  39. def use_monad(x, y):
  40. res = P.Mul()(x, y)
  41. res = F.depend(res, monad.U)
  42. return res
  43. def test_use_monad():
  44. x = Tensor(1.0, mstype.float32)
  45. y = Tensor(1.0, mstype.float32)
  46. print(use_monad(x, y))
  47. @ms_function
  48. def use_tuple_of_tensor():
  49. me_x = (Tensor(1), Tensor(1))
  50. return me_x
  51. def test_tuple_of_tensor():
  52. """
  53. Feature: JIT Fallback
  54. Description: Test tuple of tensor in graph mode.
  55. Expectation: No exception.
  56. """
  57. print(use_tuple_of_tensor())
  58. @ms_function
  59. def use_list_of_tensor():
  60. me_x = [Tensor(1), Tensor(1)]
  61. return me_x
  62. def test_list_of_tensor():
  63. """
  64. Feature: JIT Fallback
  65. Description: Test list of tensor in graph mode.
  66. Expectation: No exception.
  67. """
  68. print(use_list_of_tensor())
  69. class Net(nn.Cell):
  70. def __init__(self):
  71. super(Net, self).__init__()
  72. self.x = Tensor([2, 3, 4])
  73. def construct(self):
  74. x_len = len(self.x)
  75. for i in range(x_len):
  76. print(i)
  77. return x_len
  78. def test_builtins_len():
  79. net = Net()
  80. net()
  81. @ms_function
  82. def np_fallback_func():
  83. array_x = tuple([2, 3, 4, 5])
  84. np_x = np.array(array_x).astype(np.float32)
  85. me_x = Tensor(np_x)
  86. me_x = me_x + me_x
  87. return me_x
  88. def test_np_fallback_func():
  89. print(np_fallback_func())
  90. # Test `return` interpret node.
  91. @ms_function
  92. def div_mod_func1():
  93. x = 8
  94. y = 3
  95. a = divmod(x, y)
  96. return Tensor(a)
  97. def test_div_mod_func1():
  98. print(div_mod_func1()) # (2, 2)
  99. # Test interpret node with parameters as input.
  100. @ms_function
  101. def div_mod_func2(x, y):
  102. a = divmod(x, y)
  103. return Tensor(a)
  104. def test_div_mod_func2_scalar():
  105. """
  106. Feature: JIT Fallback
  107. Description: Test divmod in graph.
  108. Expectation: No exception.
  109. """
  110. print(div_mod_func2(8, 3)) # (2, 2)
  111. @pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
  112. def test_div_mod_func2_tensor():
  113. """
  114. Feature: JIT Fallback
  115. Description: Test divmod with Tensor input in graph. We'll support it in Tensor Input Fallback solution.
  116. Expectation: Not supported exception.
  117. """
  118. with pytest.raises(RuntimeError) as err:
  119. print(div_mod_func2(Tensor(8), Tensor(3)))
  120. assert "Not support Tensor or variable type as input during running JIT Fallback, but got" in str(err.value)
  121. @ms_function
  122. def select_func(cond, x, y):
  123. if isinstance(cond, (tuple, list)):
  124. output = y
  125. elif isinstance(cond, Tensor):
  126. output = F.select(cond, x, y)
  127. else:
  128. output = x
  129. return output
  130. def test_select_func():
  131. cond = Tensor([True, False])
  132. x = Tensor([2, 3], mstype.float32)
  133. y = Tensor([1, 2], mstype.float32)
  134. print(select_func(cond, x, y))
  135. @ms_function
  136. def select_func2(cond, x, y):
  137. if isinstance(cond, (tuple, list)):
  138. output = y
  139. if isinstance(cond, Tensor):
  140. output = F.select(cond, x, y)
  141. else:
  142. output = x
  143. return output
  144. def test_select_func2():
  145. cond = Tensor([True, False])
  146. x = Tensor([2, 3], mstype.float32)
  147. y = Tensor([1, 2], mstype.float32)
  148. print(select_func2(cond, x, y))
@ms_function
def slice_func(a, b):
    """Assign `b` into rows 1..2 of `a` via slice assignment, then return `a`.

    The in-place tensor slice assignment is the behavior under test;
    the exact `a[1:3, ::] = b` form is intentional.
    """
    a[1:3, ::] = b
    return a
  153. def test_slice_func():
  154. a = Tensor(np.arange(60).reshape(3, 4, 5), dtype=mstype.float32)
  155. b = Tensor([1], dtype=mstype.float32)
  156. print(slice_func(a, b))
  157. def test_context():
  158. """
  159. Feature: JIT Fallback
  160. Description: Test context in graph.
  161. Expectation: No exception.
  162. """
  163. class ContextNet(nn.Cell):
  164. def __init__(self):
  165. super(ContextNet, self).__init__()
  166. self.mode = context.get_context("mode")
  167. def construct(self):
  168. out = 1
  169. if self.mode == context.GRAPH_MODE:
  170. out = 2
  171. return out
  172. net = ContextNet()
  173. out = net()
  174. print(out)
  175. def test_scipy_module():
  176. """
  177. Feature: JIT Fallback
  178. Description: Test scipy module in graph.
  179. Expectation: No exception.
  180. """
  181. class Network(nn.Cell):
  182. def construct(self, x):
  183. return alg.eigh(x)
  184. net = Network()
  185. x = Tensor([[2, 0, 0, 0], [0, 1, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]])
  186. out = net(x)
  187. print(out)
  188. def test_probability_cauchy():
  189. """
  190. Feature: JIT Fallback
  191. Description: NumPy method is called in probability cauchy.
  192. Expectation: No exception.
  193. """
  194. class CauchyProb(nn.Cell):
  195. def __init__(self, loc, scale, seed=10, dtype=mstype.float32, name='Cauchy'):
  196. super().__init__()
  197. self.b = distribution.Cauchy(loc, scale, seed, dtype, name)
  198. def construct(self, value, loc=None, scale=None):
  199. out1 = self.b.prob(value, loc, scale)
  200. out2 = self.b.log_prob(value, loc, scale)
  201. out3 = self.b.cdf(value, loc, scale)
  202. out4 = self.b.log_cdf(value, loc, scale)
  203. out5 = self.b.survival_function(value, loc, scale)
  204. out6 = self.b.log_survival(value, loc, scale)
  205. return out1, out2, out3, out4, out5, out6
  206. loc = np.random.randn(1024, 512, 7, 7).astype(np.float32)
  207. scale = np.random.uniform(0.0001, 100, size=(1024, 512, 7, 7)).astype(np.float32)
  208. loc_a = np.random.randn(1024, 512, 7, 7).astype(np.float32)
  209. scale_a = np.random.uniform(0.0001, 100, size=(1024, 512, 7, 7)).astype(np.float32)
  210. value = np.random.randn(1024, 512, 7, 7).astype(np.float32)
  211. net = CauchyProb(loc, scale)
  212. net(Tensor(value), Tensor(loc_a), Tensor(scale_a))
  213. def test_third_party_module_functools():
  214. """
  215. Feature: JIT Fallback
  216. Description: functools is a python built-in module and does not perform JIT Fallback.
  217. Expectation: No exception.
  218. """
  219. class ModuleNet(nn.Cell):
  220. def construct(self, x, y):
  221. func = functools.partial(add_func, x)
  222. out = func(y)
  223. return out
  224. x = Tensor([1, 2, 3], mstype.int32)
  225. y = Tensor([4, 5, 6], mstype.int32)
  226. net = ModuleNet()
  227. out = net(x, y)
  228. print(out)