You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_graph_fallback.py 8.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test graph fallback """
  16. import pytest
  17. import numpy as np
  18. import mindspore.nn as nn
  19. from mindspore import Tensor, ms_function, context
  20. from mindspore.ops import operations as P
  21. from mindspore.ops import functional as F
  22. from mindspore.nn.probability import distribution
  23. import mindspore.common.dtype as mstype
  24. import mindspore.common._monad as monad
  25. context.set_context(mode=context.GRAPH_MODE)
  26. # `add_func` is defined in current file.
  27. def add_func(x, y):
  28. return x + y
@ms_function
def do_increment(i):
    """Return i + 1, computed through F.partial over the module-level add_func."""
    # Binding the first argument of add_func to 1 is the construct under test:
    # calling a partial of a locally defined Python function inside a graph.
    add_1 = F.partial(add_func, 1)
    return add_1(i)
  33. def test_increment():
  34. a = do_increment(9)
  35. assert a == 10
@ms_function
def use_monad(x, y):
    """Multiply x by y and attach a U-monad dependency to the product."""
    res = P.Mul()(x, y)
    # F.depend ties the result to the universal monad so monad handling is
    # exercised in graph mode; the returned value is still the product.
    res = F.depend(res, monad.U)
    return res
  41. def test_use_monad():
  42. x = Tensor(1.0, mstype.float32)
  43. y = Tensor(1.0, mstype.float32)
  44. print(use_monad(x, y))
@ms_function
def use_tuple_of_tensor():
    """Construct and return a tuple of two scalar Tensors inside a graph."""
    me_x = (Tensor(1), Tensor(1))
    return me_x
  49. def test_tuple_of_tensor():
  50. """
  51. Feature: JIT Fallback
  52. Description: Test tuple of tensor in graph mode.
  53. Expectation: No exception.
  54. """
  55. print(use_tuple_of_tensor())
@ms_function
def use_list_of_tensor():
    """Construct and return a list of two scalar Tensors inside a graph."""
    me_x = [Tensor(1), Tensor(1)]
    return me_x
  60. def test_list_of_tensor():
  61. """
  62. Feature: JIT Fallback
  63. Description: Test list of tensor in graph mode.
  64. Expectation: No exception.
  65. """
  66. print(use_list_of_tensor())
class Net(nn.Cell):
    """Cell exercising the builtin len() on a Tensor attribute in graph mode."""

    def __init__(self):
        super(Net, self).__init__()
        # 1-D Tensor attribute; its length is taken inside construct.
        self.x = Tensor([2, 3, 4])

    def construct(self):
        x_len = len(self.x)  # len() of the 3-element Tensor attribute
        for i in range(x_len):
            print(i)
        return x_len
def test_builtins_len():
    """
    Feature: JIT Fallback
    Description: Test the builtin len() on a Tensor attribute in graph mode.
    Expectation: No exception.
    """
    net = Net()
    net()
@ms_function
def np_fallback_func():
    """Build a Tensor from a numpy array created inside the graph, then double it."""
    # tuple() and np.array() are Python calls handled by JIT fallback;
    # keep these exact constructs — they are the behavior under test.
    array_x = tuple([2, 3, 4, 5])
    np_x = np.array(array_x).astype(np.float32)
    me_x = Tensor(np_x)
    me_x = me_x + me_x
    return me_x
  86. def test_np_fallback_func():
  87. print(np_fallback_func())
  88. # Test `return` interpret node.
@ms_function
def div_mod_func1():
    """Return Tensor(divmod(8, 3)); the divmod result is a returned interpret node."""
    x = 8
    y = 3
    a = divmod(x, y)  # (2, 2)
    return Tensor(a)
  95. def test_div_mod_func1():
  96. print(div_mod_func1()) # (2, 2)
  97. # Test interpret node with parameters as input.
@ms_function
def div_mod_func2(x, y):
    """Return Tensor(divmod(x, y)); exercises an interpret node fed by graph parameters."""
    a = divmod(x, y)
    return Tensor(a)
  102. def test_div_mod_func2_scalar():
  103. """
  104. Feature: JIT Fallback
  105. Description: Test divmod in graph.
  106. Expectation: No exception.
  107. """
  108. print(div_mod_func2(8, 3)) # (2, 2)
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_div_mod_func2_tensor():
    """
    Feature: JIT Fallback
    Description: Test divmod with Tensor input in graph. We'll support it in Tensor Input Fallback solution.
    Expectation: Not supported exception.
    """
    # Tensor inputs reaching an interpret node are expected to be rejected
    # with the exact message asserted below.
    with pytest.raises(RuntimeError) as err:
        print(div_mod_func2(Tensor(8), Tensor(3)))
    assert "Not support Tensor or variable type as input during running JIT Fallback, but got" in str(err.value)
@ms_function
def select_func(cond, x, y):
    """Dispatch on cond's type: tuple/list -> y, Tensor -> F.select(cond, x, y), else x."""
    if isinstance(cond, (tuple, list)):
        output = y
    elif isinstance(cond, Tensor):
        output = F.select(cond, x, y)
    else:
        output = x
    return output
  128. def test_select_func():
  129. cond = Tensor([True, False])
  130. x = Tensor([2, 3], mstype.float32)
  131. y = Tensor([1, 2], mstype.float32)
  132. print(select_func(cond, x, y))
@ms_function
def select_func2(cond, x, y):
    """Variant of select_func with two independent if statements instead of if/elif."""
    if isinstance(cond, (tuple, list)):
        output = y
    # NOTE(review): this `if` is deliberately NOT `elif` — for a tuple/list cond
    # the following `else` overwrites output with x. Presumably intentional to
    # exercise independent-if control flow in graph mode; confirm before "fixing".
    if isinstance(cond, Tensor):
        output = F.select(cond, x, y)
    else:
        output = x
    return output
  142. def test_select_func2():
  143. cond = Tensor([True, False])
  144. x = Tensor([2, 3], mstype.float32)
  145. y = Tensor([1, 2], mstype.float32)
  146. print(select_func2(cond, x, y))
@ms_function
def slice_func(a, b):
    """Write b into a[1:3, :, ...] by slice assignment and return the result."""
    # In-place slice assignment in a graph; b is expanded across the slice
    # (caller passes a single-element Tensor) — confirm broadcast semantics.
    a[1:3, ::] = b
    return a
  151. def test_slice_func():
  152. a = Tensor(np.arange(60).reshape(3, 4, 5), dtype=mstype.float32)
  153. b = Tensor([1], dtype=mstype.float32)
  154. print(slice_func(a, b))
def test_context():
    """
    Feature: JIT Fallback
    Description: Test context in graph.
    Expectation: No exception.
    """
    class ContextNet(nn.Cell):
        def __init__(self):
            super(ContextNet, self).__init__()
            # Mode is read once at construction time (plain Python), then the
            # stored value is compared inside the compiled construct.
            self.mode = context.get_context("mode")

        def construct(self):
            out = 1
            # File sets GRAPH_MODE at import time, so this branch is taken.
            if self.mode == context.GRAPH_MODE:
                out = 2
            return out

    net = ContextNet()
    out = net()
    print(out)
def test_self_attr():
    """
    Feature: JIT Fallback
    Description: Test self.attr in graph.
    Expectation: No exception.
    """
    class Network(nn.Cell):
        def __init__(self):
            super(Network, self).__init__()
            # Plain Python int attribute read inside construct via fallback.
            self.dim = 1

        def construct(self, x):
            batch = x.shape[0]
            # np.ones shape depends on both a runtime shape and self.dim.
            # NOTE(review): `one` is float16 while the caller's x is float32 —
            # relies on implicit type promotion in the multiply; confirm.
            one = Tensor(np.ones([batch, self.dim]), mstype.float16)
            return one * x

    net = Network()
    x = Tensor([1, 2], mstype.float32)
    out = net(x)
    print(out)
def test_self_attr_2():
    """
    Feature: JIT Fallback
    Description: Test self.attr in graph.
    Expectation: No exception.
    """
    class Network(nn.Cell):
        def __init__(self, fn):
            super(Network, self).__init__()
            # Stores a plain Python function; invoked on numpy arrays in construct.
            self.fn = fn

        def construct(self):
            x = np.array([1, 2, 3])
            y = np.array([3, 4, 5])
            out = Tensor(self.fn(x, y))
            return out

    def fn(x, y):
        return x + y

    net = Network(fn)
    out = net()
    print(out)
def test_self_attr_3():
    """
    Feature: JIT Fallback
    Description: Test self.attr in graph.
    Expectation: No exception.
    """
    class Network(nn.Cell):
        def __init__(self):
            super(Network, self).__init__()
            # Plain Python list attribute; its count() method is called in construct.
            self.value = [2, 2, 3]

        def construct(self):
            x = np.array(self.value.count(2))  # count(2) over [2, 2, 3] -> 2
            return Tensor(x)

    net = Network()
    out = net()
    print(out)
def test_self_method():
    """
    Feature: JIT Fallback
    Description: Test self.method in graph.
    Expectation: No exception.
    """
    class Network(nn.Cell):
        def construct(self):
            x = np.array([1, 2, 3])
            y = np.array([3, 4, 5])
            # self.fn is a plain (non-construct) method called from the graph;
            # its result feeds Tensor() directly.
            out = Tensor(self.fn(x, y))
            return out

        def fn(self, x, y):
            return x + y

    net = Network()
    out = net()
    print(out)
@pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
def test_self_method_2():
    """
    Feature: JIT Fallback
    Description: Test self.method in graph.
    Expectation: No exception.
    """
    class Network(nn.Cell):
        def construct(self):
            x = np.array([1, 2, 3])
            y = np.array([3, 4, 5])
            # Unlike test_self_method, the method result is bound to an
            # intermediate local before the Tensor() call — currently skipped.
            z = self.fn(x, y)
            out = Tensor(z)
            return out

        def fn(self, x, y):
            return x + y

    net = Network()
    out = net()
    print(out)
def test_probability_cauchy():
    """
    Feature: JIT Fallback
    Description: NumPy method is called in probability cauchy.
    Expectation: No exception.
    """
    class CauchyProb(nn.Cell):
        def __init__(self, loc, scale, seed=10, dtype=mstype.float32, name='Cauchy'):
            super().__init__()
            # Cauchy distribution built from numpy-array loc/scale parameters.
            self.b = distribution.Cauchy(loc, scale, seed, dtype, name)

        def construct(self, value, loc=None, scale=None):
            # Exercise the full probability-method surface of the distribution;
            # loc/scale override the constructor parameters per call.
            out1 = self.b.prob(value, loc, scale)
            out2 = self.b.log_prob(value, loc, scale)
            out3 = self.b.cdf(value, loc, scale)
            out4 = self.b.log_cdf(value, loc, scale)
            out5 = self.b.survival_function(value, loc, scale)
            out6 = self.b.log_survival(value, loc, scale)
            return out1, out2, out3, out4, out5, out6

    # Large (1024, 512, 7, 7) random fixtures; scale values are kept strictly
    # positive via np.random.uniform(0.0001, 100, ...).
    loc = np.random.randn(1024, 512, 7, 7).astype(np.float32)
    scale = np.random.uniform(0.0001, 100, size=(1024, 512, 7, 7)).astype(np.float32)
    loc_a = np.random.randn(1024, 512, 7, 7).astype(np.float32)
    scale_a = np.random.uniform(0.0001, 100, size=(1024, 512, 7, 7)).astype(np.float32)
    value = np.random.randn(1024, 512, 7, 7).astype(np.float32)
    net = CauchyProb(loc, scale)
    net(Tensor(value), Tensor(loc_a), Tensor(scale_a))