You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_graph_fallback_class.py 9.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
"""Graph-mode JIT Fallback tests: self attributes/methods, imported modules, and @ms_class user classes."""
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, context, ms_class, ms_function
from . import test_graph_fallback
# Set once at import time: every Cell / ms_function in this module is compiled in graph mode.
context.set_context(mode=context.GRAPH_MODE)
  23. def test_fallback_self_attr():
  24. """
  25. Feature: JIT Fallback
  26. Description: Test self.attr in graph.
  27. Expectation: No exception.
  28. """
  29. class Network(nn.Cell):
  30. def __init__(self):
  31. super(Network, self).__init__()
  32. self.dim = 1
  33. def construct(self, x):
  34. batch = x.shape[0]
  35. one = Tensor(np.ones([batch, self.dim]), mstype.float32)
  36. return one * x
  37. net = Network()
  38. x = Tensor([1, 2], mstype.float32)
  39. out = net(x)
  40. expect = np.array([[1., 2.], [1., 2.]])
  41. assert np.allclose(out.asnumpy(), expect, 1.e-2, 1.e-2)
  42. def test_fallback_self_attr_fn():
  43. """
  44. Feature: JIT Fallback
  45. Description: Test self.attr in graph.
  46. Expectation: No exception.
  47. """
  48. class Network(nn.Cell):
  49. def __init__(self, fn):
  50. super(Network, self).__init__()
  51. self.fn = fn
  52. def construct(self):
  53. x = np.array([1, 2, 3])
  54. y = np.array([3, 4, 5])
  55. out = Tensor(self.fn(x, y))
  56. return out
  57. def fn(x, y):
  58. return x + y
  59. net = Network(fn)
  60. out = net()
  61. expect = np.array([4, 6, 8])
  62. assert np.all(out.asnumpy() == expect)
  63. def test_fallback_self_attr_attr():
  64. """
  65. Feature: JIT Fallback
  66. Description: Test self.attr in graph.
  67. Expectation: No exception.
  68. """
  69. class Network(nn.Cell):
  70. def __init__(self):
  71. super(Network, self).__init__()
  72. self.value = [2, 2, 3]
  73. def construct(self):
  74. x = np.array(self.value.count(2))
  75. return Tensor(x)
  76. net = Network()
  77. out = net()
  78. assert out == 2
  79. def test_fallback_self_method():
  80. """
  81. Feature: JIT Fallback
  82. Description: Test self.method in graph.
  83. Expectation: No exception.
  84. """
  85. class Network(nn.Cell):
  86. def construct(self):
  87. x = np.array([1, 2, 3])
  88. y = np.array([3, 4, 5])
  89. out = Tensor(self.fn(x, y))
  90. return out
  91. def fn(self, x, y):
  92. return x + y
  93. net = Network()
  94. out = net()
  95. expect = np.array([4, 6, 8])
  96. assert np.all(out.asnumpy() == expect)
  97. @pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
  98. def test_fallback_self_method_tensor():
  99. """
  100. Feature: JIT Fallback
  101. Description: Test self.method in graph.
  102. Expectation: No exception.
  103. """
  104. class Network(nn.Cell):
  105. def construct(self):
  106. x = np.array([1, 2, 3])
  107. y = np.array([3, 4, 5])
  108. z = self.fn(x, y)
  109. out = Tensor(z)
  110. return out
  111. def fn(self, x, y):
  112. return x + y
  113. net = Network()
  114. out = net()
  115. print(out)
  116. def test_fallback_import_modules():
  117. """
  118. Feature: JIT Fallback
  119. Description: add_func is defined in test_graph_fallback.py
  120. Expectation: No exception.
  121. """
  122. @ms_function
  123. def use_imported_module(x, y):
  124. out = test_graph_fallback.add_func(x, y)
  125. return out
  126. x = Tensor(2, dtype=mstype.int32)
  127. y = Tensor(3, dtype=mstype.int32)
  128. out = use_imported_module(x, y)
  129. print(out)
  130. def test_fallback_class_attr():
  131. """
  132. Feature: JIT Fallback
  133. Description: Test user-defined class attributes in graph.
  134. Expectation: No exception.
  135. """
  136. @ms_class
  137. class InnerNet:
  138. def __init__(self):
  139. self.number = 1
  140. class Net(nn.Cell):
  141. def __init__(self):
  142. super(Net, self).__init__()
  143. self.inner_net = InnerNet()
  144. def construct(self):
  145. out = self.inner_net.number
  146. return out
  147. net = Net()
  148. out = net()
  149. assert out == 1
  150. def test_fallback_class_method():
  151. """
  152. Feature: JIT Fallback
  153. Description: Test user-defined class methods in graph.
  154. Expectation: No exception.
  155. """
  156. @ms_class
  157. class InnerNet:
  158. def __init__(self):
  159. self.val = 2
  160. def act(self, x, y):
  161. return self.val * (x + y)
  162. class Net(nn.Cell):
  163. def __init__(self):
  164. super(Net, self).__init__()
  165. self.inner_net = InnerNet()
  166. def construct(self):
  167. out = self.inner_net.act(1, 2)
  168. return out
  169. net = Net()
  170. out = net()
  171. assert out == 6
  172. def test_fallback_class_input_attr():
  173. """
  174. Feature: JIT Fallback
  175. Description: Test user-defined class attributes in graph.
  176. Expectation: No exception.
  177. """
  178. @ms_class
  179. class InnerNet:
  180. def __init__(self):
  181. self.number = Tensor(np.array([1, 2, 3]))
  182. class Net(nn.Cell):
  183. def __init__(self, net):
  184. super(Net, self).__init__()
  185. self.inner_net = net()
  186. def construct(self):
  187. out = self.inner_net.number
  188. return out
  189. net = Net(InnerNet)
  190. out = net()
  191. expect_res = np.array([1, 2, 3])
  192. assert np.all(out.asnumpy() == expect_res)
  193. def test_fallback_class_input_method():
  194. """
  195. Feature: JIT Fallback
  196. Description: Test user-defined class methods in graph.
  197. Expectation: No exception.
  198. """
  199. @ms_class
  200. class InnerNet:
  201. def __init__(self):
  202. self.val = 2
  203. def act(self, x, y):
  204. return self.val * (x + y)
  205. class Net(nn.Cell):
  206. def __init__(self, net):
  207. super(Net, self).__init__()
  208. self.inner_net = net()
  209. def construct(self):
  210. out = self.inner_net.act(1, 2)
  211. return out
  212. net = Net(InnerNet)
  213. out = net()
  214. assert out == 6
  215. def test_fallback_class_class_nested():
  216. """
  217. Feature: JIT Fallback
  218. Description: Test nested ms_class in graph.
  219. Expectation: No exception.
  220. """
  221. @ms_class
  222. class Inner:
  223. def __init__(self):
  224. self.number = 1
  225. @ms_class
  226. class InnerNet:
  227. def __init__(self):
  228. self.inner = Inner()
  229. class Net(nn.Cell):
  230. def __init__(self):
  231. super(Net, self).__init__()
  232. self.inner_net = InnerNet()
  233. def construct(self):
  234. out = self.inner_net.inner.number
  235. return out
  236. net = Net()
  237. out = net()
  238. assert out == 1
  239. def test_fallback_class_cell_nested():
  240. """
  241. Feature: JIT Fallback
  242. Description: Test nested ms_class and cell in graph.
  243. Expectation: No exception.
  244. """
  245. class Net(nn.Cell):
  246. def __init__(self, val):
  247. super().__init__()
  248. self.val = val
  249. def construct(self, x):
  250. return x + self.val
  251. @ms_class
  252. class TrainNet():
  253. class Loss(nn.Cell):
  254. def __init__(self, net):
  255. super().__init__()
  256. self.net = net
  257. def construct(self, x):
  258. out = self.net(x)
  259. return out * 2
  260. def __init__(self, net):
  261. self.net = net
  262. loss_net = self.Loss(self.net)
  263. self.number = loss_net(10)
  264. global_net = Net(1)
  265. class LearnNet(nn.Cell):
  266. def __init__(self):
  267. super().__init__()
  268. self.value = TrainNet(global_net).number
  269. def construct(self, x):
  270. return x + self.value
  271. leanrn_net = LearnNet()
  272. out = leanrn_net(3)
  273. print(out)
  274. assert out == 25
  275. @pytest.mark.skip(reason='Not support in graph yet')
  276. def test_fallback_class_isinstance():
  277. """
  278. Feature: JIT Fallback
  279. Description: Test ms_class in graph.
  280. Expectation: No exception.
  281. """
  282. @ms_class
  283. class InnerNet:
  284. def __init__(self):
  285. self.number = 1
  286. class Net(nn.Cell):
  287. def __init__(self):
  288. super(Net, self).__init__()
  289. self.inner_net = InnerNet()
  290. def construct(self, x):
  291. if isinstance(self.inner_net, InnerNet):
  292. return x + 10
  293. return x
  294. net = Net()
  295. out = net(5)
  296. assert out == 15
  297. def test_fallback_raise_error_not_class_type():
  298. """
  299. Feature: JIT Fallback
  300. Description: Test ms_class in graph.
  301. Expectation: No exception.
  302. """
  303. with pytest.raises(TypeError):
  304. @ms_class
  305. def func(x, y):
  306. return x + y
  307. func(1, 2)
  308. def test_fallback_raise_error_not_class_instance():
  309. """
  310. Feature: JIT Fallback
  311. Description: Test ms_class in graph.
  312. Expectation: No exception.
  313. """
  314. @ms_class
  315. class InnerNet:
  316. def __init__(self):
  317. self.number = 1
  318. class Net(nn.Cell):
  319. def construct(self):
  320. out = InnerNet().number
  321. return out
  322. with pytest.raises(ValueError):
  323. net = Net()
  324. net()
  325. def test_fallback_raise_error_decorate_cell():
  326. """
  327. Feature: JIT Fallback
  328. Description: Test ms_class in graph.
  329. Expectation: No exception.
  330. """
  331. @ms_class
  332. class Net(nn.Cell):
  333. def construct(self, x):
  334. return x
  335. with pytest.raises(TypeError):
  336. x = Tensor(1)
  337. net = Net()
  338. net(x)