You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_graph_fallback.py 9.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test graph fallback """
  16. import pytest
  17. import numpy as np
  18. import mindspore.nn as nn
  19. from mindspore import Tensor, ms_function, context
  20. from mindspore.ops import operations as P
  21. from mindspore.ops import functional as F
  22. from mindspore.nn.probability import distribution
  23. import mindspore.common.dtype as mstype
  24. import mindspore.common._monad as monad
  25. import mindspore.scipy.linalg as alg
  26. context.set_context(mode=context.GRAPH_MODE)
  27. # `add_func` is defined in current file.
  28. def add_func(x, y):
  29. return x + y
@ms_function
def do_increment(i):
    # Bind the first argument of the locally defined `add_func` with
    # F.partial, then call the resulting callable inside the graph —
    # this partial-application pattern is the fallback feature under test.
    add_1 = F.partial(add_func, 1)
    return add_1(i)
  34. def test_increment():
  35. a = do_increment(9)
  36. assert a == 10
@ms_function
def use_monad(x, y):
    res = P.Mul()(x, y)
    # Attach the universal monad so the multiply is sequenced before the
    # return; exercises F.depend with a monad operand in the graph.
    res = F.depend(res, monad.U)
    return res
  42. def test_use_monad():
  43. x = Tensor(1.0, mstype.float32)
  44. y = Tensor(1.0, mstype.float32)
  45. print(use_monad(x, y))
@ms_function
def use_tuple_of_tensor():
    # Constructing Tensors inside a graph function relies on JIT fallback.
    me_x = (Tensor(1), Tensor(1))
    return me_x
  50. def test_tuple_of_tensor():
  51. """
  52. Feature: JIT Fallback
  53. Description: Test tuple of tensor in graph mode.
  54. Expectation: No exception.
  55. """
  56. print(use_tuple_of_tensor())
@ms_function
def use_list_of_tensor():
    # Same as the tuple case above, but with a Python list literal.
    me_x = [Tensor(1), Tensor(1)]
    return me_x
  61. def test_list_of_tensor():
  62. """
  63. Feature: JIT Fallback
  64. Description: Test list of tensor in graph mode.
  65. Expectation: No exception.
  66. """
  67. print(use_list_of_tensor())
class Net(nn.Cell):
    """Cell exercising the builtin `len` and `print` on a Tensor attribute."""

    def __init__(self):
        super(Net, self).__init__()
        # Tensor attribute whose length is taken inside construct.
        self.x = Tensor([2, 3, 4])

    def construct(self):
        # `len` on a Tensor attribute is resolved through JIT fallback.
        x_len = len(self.x)
        for i in range(x_len):
            print(i)
        return x_len
  77. def test_builtins_len():
  78. net = Net()
  79. net()
@ms_function
def np_fallback_func():
    # Build a numpy array from a Python tuple inside the graph (via
    # fallback), convert it to a Tensor, then run a graph-level add.
    array_x = tuple([2, 3, 4, 5])
    np_x = np.array(array_x).astype(np.float32)
    me_x = Tensor(np_x)
    me_x = me_x + me_x
    return me_x
  87. def test_np_fallback_func():
  88. print(np_fallback_func())
# Test `return` interpret node.
@ms_function
def div_mod_func1():
    x = 8
    y = 3
    # divmod on constants is interpreted via fallback; the interpreted
    # result feeds directly into the return node.
    a = divmod(x, y)
    return Tensor(a)
  96. def test_div_mod_func1():
  97. print(div_mod_func1()) # (2, 2)
# Test interpret node with parameters as input.
@ms_function
def div_mod_func2(x, y):
    # Unlike div_mod_func1, the divmod operands are graph parameters here.
    a = divmod(x, y)
    return Tensor(a)
  103. def test_div_mod_func2_scalar():
  104. """
  105. Feature: JIT Fallback
  106. Description: Test divmod in graph.
  107. Expectation: No exception.
  108. """
  109. print(div_mod_func2(8, 3)) # (2, 2)
  110. @pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
  111. def test_div_mod_func2_tensor():
  112. """
  113. Feature: JIT Fallback
  114. Description: Test divmod with Tensor input in graph. We'll support it in Tensor Input Fallback solution.
  115. Expectation: Not supported exception.
  116. """
  117. with pytest.raises(RuntimeError) as err:
  118. print(div_mod_func2(Tensor(8), Tensor(3)))
  119. assert "Not support Tensor or variable type as input during running JIT Fallback, but got" in str(err.value)
@ms_function
def select_func(cond, x, y):
    # isinstance dispatch inside a graph: tuple/list and Tensor branches
    # are resolved at compile time via fallback.
    if isinstance(cond, (tuple, list)):
        output = y
    elif isinstance(cond, Tensor):
        output = F.select(cond, x, y)
    else:
        output = x
    return output
  129. def test_select_func():
  130. cond = Tensor([True, False])
  131. x = Tensor([2, 3], mstype.float32)
  132. y = Tensor([1, 2], mstype.float32)
  133. print(select_func(cond, x, y))
@ms_function
def select_func2(cond, x, y):
    # NOTE: unlike `select_func`, the second test is a plain `if`, not
    # `elif`. When `cond` is a tuple/list, the first assignment is
    # immediately overwritten by the `else` branch below — this if/if/else
    # shape is the deliberate compilation pattern under test, not a bug.
    if isinstance(cond, (tuple, list)):
        output = y
    if isinstance(cond, Tensor):
        output = F.select(cond, x, y)
    else:
        output = x
    return output
  143. def test_select_func2():
  144. cond = Tensor([True, False])
  145. x = Tensor([2, 3], mstype.float32)
  146. y = Tensor([1, 2], mstype.float32)
  147. print(select_func2(cond, x, y))
@ms_function
def slice_func(a, b):
    # In-place slice assignment on a Tensor parameter inside the graph.
    a[1:3, ::] = b
    return a
  152. def test_slice_func():
  153. a = Tensor(np.arange(60).reshape(3, 4, 5), dtype=mstype.float32)
  154. b = Tensor([1], dtype=mstype.float32)
  155. print(slice_func(a, b))
def test_context():
    """
    Feature: JIT Fallback
    Description: Test context in graph.
    Expectation: No exception.
    """
    class ContextNet(nn.Cell):
        def __init__(self):
            super(ContextNet, self).__init__()
            # Read the execution mode at cell-build time (a Python value).
            self.mode = context.get_context("mode")

        def construct(self):
            out = 1
            # Comparing the Python attribute against context.GRAPH_MODE
            # inside the graph goes through fallback.
            if self.mode == context.GRAPH_MODE:
                out = 2
            return out

    net = ContextNet()
    out = net()
    print(out)
def test_scipy_module():
    """
    Feature: JIT Fallback
    Description: Test scipy module in graph.
    Expectation: No exception.
    """
    class Network(nn.Cell):
        def construct(self, x):
            # mindspore.scipy.linalg.eigh invoked from within a graph.
            return alg.eigh(x)

    net = Network()
    x = Tensor([[2, 0, 0, 0], [0, 1, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]])
    out = net(x)
    print(out)
def test_self_attr():
    """
    Feature: JIT Fallback
    Description: Test self.attr in graph.
    Expectation: No exception.
    """
    class Network(nn.Cell):
        def __init__(self):
            super(Network, self).__init__()
            # Plain Python int attribute read inside construct.
            self.dim = 1

        def construct(self, x):
            batch = x.shape[0]
            # np.ones consumes the Python attribute self.dim via fallback.
            one = Tensor(np.ones([batch, self.dim]), mstype.float16)
            return one * x

    net = Network()
    x = Tensor([1, 2], mstype.float32)
    out = net(x)
    print(out)
def test_self_attr_2():
    """
    Feature: JIT Fallback
    Description: Test self.attr in graph.
    Expectation: No exception.
    """
    class Network(nn.Cell):
        def __init__(self, fn):
            super(Network, self).__init__()
            # A plain Python function stored as a cell attribute.
            self.fn = fn

        def construct(self):
            x = np.array([1, 2, 3])
            y = np.array([3, 4, 5])
            # The stored Python callable runs on numpy values via fallback.
            out = Tensor(self.fn(x, y))
            return out

    def fn(x, y):
        return x + y

    net = Network(fn)
    out = net()
    print(out)
def test_self_attr_3():
    """
    Feature: JIT Fallback
    Description: Test self.attr in graph.
    Expectation: No exception.
    """
    class Network(nn.Cell):
        def __init__(self):
            super(Network, self).__init__()
            # Plain Python list attribute used inside construct.
            self.value = [2, 2, 3]

        def construct(self):
            # list.count on the Python attribute runs via fallback.
            x = np.array(self.value.count(2))
            return Tensor(x)

    net = Network()
    out = net()
    print(out)
def test_self_method():
    """
    Feature: JIT Fallback
    Description: Test self.method in graph.
    Expectation: No exception.
    """
    class Network(nn.Cell):
        def construct(self):
            x = np.array([1, 2, 3])
            y = np.array([3, 4, 5])
            # self.fn is an ordinary Python method (not construct),
            # executed on numpy values through fallback.
            out = Tensor(self.fn(x, y))
            return out

        def fn(self, x, y):
            return x + y

    net = Network()
    out = net()
    print(out)
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_self_method_2():
    """
    Feature: JIT Fallback
    Description: Test self.method in graph.
    Expectation: No exception.
    """
    class Network(nn.Cell):
        def construct(self):
            x = np.array([1, 2, 3])
            y = np.array([3, 4, 5])
            # Unlike test_self_method, the method result is first bound to
            # a local before the Tensor conversion — currently unsupported.
            z = self.fn(x, y)
            out = Tensor(z)
            return out

        def fn(self, x, y):
            return x + y

    net = Network()
    out = net()
    print(out)
def test_probability_cauchy():
    """
    Feature: JIT Fallback
    Description: NumPy method is called in probability cauchy.
    Expectation: No exception.
    """
    class CauchyProb(nn.Cell):
        def __init__(self, loc, scale, seed=10, dtype=mstype.float32, name='Cauchy'):
            super().__init__()
            self.b = distribution.Cauchy(loc, scale, seed, dtype, name)

        def construct(self, value, loc=None, scale=None):
            # Exercise the full Cauchy distribution API in one graph call.
            out1 = self.b.prob(value, loc, scale)
            out2 = self.b.log_prob(value, loc, scale)
            out3 = self.b.cdf(value, loc, scale)
            out4 = self.b.log_cdf(value, loc, scale)
            out5 = self.b.survival_function(value, loc, scale)
            out6 = self.b.log_survival(value, loc, scale)
            return out1, out2, out3, out4, out5, out6

    # NOTE(review): each fixture is a 1024x512x7x7 float32 array (~100 MB
    # per array) — presumably the size is intentional for coverage; confirm
    # before shrinking for speed.
    loc = np.random.randn(1024, 512, 7, 7).astype(np.float32)
    scale = np.random.uniform(0.0001, 100, size=(1024, 512, 7, 7)).astype(np.float32)
    loc_a = np.random.randn(1024, 512, 7, 7).astype(np.float32)
    scale_a = np.random.uniform(0.0001, 100, size=(1024, 512, 7, 7)).astype(np.float32)
    value = np.random.randn(1024, 512, 7, 7).astype(np.float32)
    net = CauchyProb(loc, scale)
    net(Tensor(value), Tensor(loc_a), Tensor(scale_a))