You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_graph_fallback_class.py 9.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test graph fallback """
  16. import pytest
  17. import numpy as np
  18. import mindspore.nn as nn
  19. import mindspore.common.dtype as mstype
  20. from mindspore import Tensor, context, ms_class
  # Every test below exercises JIT Fallback "in graph", so Cells in this module
  # are compiled and run in GRAPH_MODE rather than PyNative mode.
  context.set_context(mode=context.GRAPH_MODE)
  def test_fallback_self_attr():
      """
      Feature: JIT Fallback
      Description: Read a plain int attribute (self.dim) inside construct and use it in a numpy shape.
      Expectation: No exception.
      """
      class Network(nn.Cell):
          def __init__(self):
              super().__init__()
              # Python int consumed by np.ones inside construct.
              self.dim = 1

          def construct(self, x):
              batch = x.shape[0]
              one = Tensor(np.ones([batch, self.dim]), mstype.float32)
              return one * x

      input_x = Tensor([1, 2], mstype.float32)
      result = Network()(input_x)
      # A (2, 1) block of ones times a length-2 vector broadcasts to (2, 2).
      expected = np.array([[1., 2.], [1., 2.]])
      assert np.allclose(result.asnumpy(), expected, rtol=1.e-2, atol=1.e-2)
  def test_fallback_self_attr_fn():
      """
      Feature: JIT Fallback
      Description: Call a plain Python function stored in self.fn on numpy arrays inside construct.
      Expectation: No exception.
      """
      def add(x, y):
          return x + y

      class Network(nn.Cell):
          def __init__(self, fn):
              super().__init__()
              # The callable is injected from outside and invoked in construct.
              self.fn = fn

          def construct(self):
              x = np.array([1, 2, 3])
              y = np.array([3, 4, 5])
              out = Tensor(self.fn(x, y))
              return out

      result = Network(add)()
      assert np.all(result.asnumpy() == np.array([4, 6, 8]))
  def test_fallback_self_attr_attr():
      """
      Feature: JIT Fallback
      Description: Call a list method (count) on a list attribute of self inside construct.
      Expectation: No exception.
      """
      class Network(nn.Cell):
          def __init__(self):
              super().__init__()
              self.value = [2, 2, 3]

          def construct(self):
              x = np.array(self.value.count(2))
              return Tensor(x)

      result = Network()()
      # [2, 2, 3] contains the value 2 exactly twice.
      assert result == 2
  def test_fallback_self_method():
      """
      Feature: JIT Fallback
      Description: Call a regular (non-construct) method via self inside construct, passing numpy arrays.
      Expectation: No exception.
      """
      class Network(nn.Cell):
          def fn(self, x, y):
              return x + y

          def construct(self):
              x = np.array([1, 2, 3])
              y = np.array([3, 4, 5])
              out = Tensor(self.fn(x, y))
              return out

      result = Network()()
      expected = np.array([4, 6, 8])
      assert np.all(result.asnumpy() == expected)
  @pytest.mark.skip(reason='Not support in graph jit fallback feature yet')
  def test_fallback_self_method_tensor():
      """
      Feature: JIT Fallback
      Description: Call self.method in construct, bind its numpy result to a local, then wrap it in a Tensor.
      Expectation: No exception, and the Tensor holds the element-wise sum.
      """
      class Network(nn.Cell):
          def construct(self):
              x = np.array([1, 2, 3])
              y = np.array([3, 4, 5])
              # Unlike test_fallback_self_method, the method result goes through a
              # local variable before Tensor(); this form is what is skipped.
              z = self.fn(x, y)
              out = Tensor(z)
              return out

          def fn(self, x, y):
              return x + y

      net = Network()
      out = net()
      # The original only printed the result; assert it so the test can fail
      # once the skip marker is removed.
      expect = np.array([4, 6, 8])
      assert np.all(out.asnumpy() == expect)
  def test_fallback_class_attr():
      """
      Feature: JIT Fallback
      Description: Read an attribute of an ms_class instance held by a Cell inside construct.
      Expectation: No exception.
      """
      @ms_class
      class InnerNet:
          def __init__(self):
              self.number = 1

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.inner_net = InnerNet()

          def construct(self):
              out = self.inner_net.number
              return out

      result = Net()()
      assert result == 1
  def test_fallback_class_method():
      """
      Feature: JIT Fallback
      Description: Call a method of an ms_class instance held by a Cell inside construct.
      Expectation: No exception.
      """
      @ms_class
      class InnerNet:
          def __init__(self):
              self.val = 2

          def act(self, x, y):
              return self.val * (x + y)

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.inner_net = InnerNet()

          def construct(self):
              out = self.inner_net.act(1, 2)
              return out

      result = Net()()
      # val * (1 + 2) with val == 2
      assert result == 6
  def test_fallback_class_input_attr():
      """
      Feature: JIT Fallback
      Description: Pass an ms_class type into a Cell, instantiate it in __init__, and read its Tensor attribute in construct.
      Expectation: No exception.
      """
      @ms_class
      class InnerNet:
          def __init__(self):
              self.number = Tensor(np.array([1, 2, 3]))

      class Net(nn.Cell):
          def __init__(self, net):
              super().__init__()
              # `net` is the class itself; it is instantiated here, outside construct.
              self.inner_net = net()

          def construct(self):
              out = self.inner_net.number
              return out

      result = Net(InnerNet)()
      assert np.all(result.asnumpy() == np.array([1, 2, 3]))
  def test_fallback_class_input_method():
      """
      Feature: JIT Fallback
      Description: Pass an ms_class type into a Cell, instantiate it in __init__, and call its method in construct.
      Expectation: No exception.
      """
      @ms_class
      class InnerNet:
          def __init__(self):
              self.val = 2

          def act(self, x, y):
              return self.val * (x + y)

      class Net(nn.Cell):
          def __init__(self, net):
              super().__init__()
              # `net` is the class itself; it is instantiated here, outside construct.
              self.inner_net = net()

          def construct(self):
              out = self.inner_net.act(1, 2)
              return out

      result = Net(InnerNet)()
      # val * (1 + 2) with val == 2
      assert result == 6
  def test_fallback_class_class_nested():
      """
      Feature: JIT Fallback
      Description: Read an attribute through two levels of ms_class instances inside construct.
      Expectation: No exception.
      """
      @ms_class
      class Inner:
          def __init__(self):
              self.number = 1

      @ms_class
      class InnerNet:
          def __init__(self):
              self.inner = Inner()

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.inner_net = InnerNet()

          def construct(self):
              out = self.inner_net.inner.number
              return out

      result = Net()()
      assert result == 1
  def test_fallback_class_cell_nested():
      """
      Feature: JIT Fallback
      Description: Test nested ms_class and cell in graph: an ms_class that defines and runs a
          Cell in its __init__, used from another Cell's __init__.
      Expectation: No exception.
      """
      class Net(nn.Cell):
          def __init__(self, val):
              super().__init__()
              self.val = val

          def construct(self, x):
              return x + self.val

      @ms_class
      class TrainNet:
          class Loss(nn.Cell):
              def __init__(self, net):
                  super().__init__()
                  self.net = net

              def construct(self, x):
                  out = self.net(x)
                  return out * 2

          def __init__(self, net):
              self.net = net
              loss_net = self.Loss(self.net)
              # The Loss cell is called inside __init__, so number = (10 + val) * 2.
              self.number = loss_net(10)

      global_net = Net(1)

      class LearnNet(nn.Cell):
          def __init__(self):
              super().__init__()
              self.value = TrainNet(global_net).number

          def construct(self, x):
              return x + self.value

      learn_net = LearnNet()
      out = learn_net(3)
      # (10 + 1) * 2 + 3 == 25
      assert out == 25
  @pytest.mark.skip(reason='Not support in graph yet')
  def test_fallback_class_isinstance():
      """
      Feature: JIT Fallback
      Description: Use isinstance against an ms_class type inside construct.
      Expectation: No exception.
      """
      @ms_class
      class InnerNet:
          def __init__(self):
              self.number = 1

      class Net(nn.Cell):
          def __init__(self):
              super().__init__()
              self.inner_net = InnerNet()

          def construct(self, x):
              if isinstance(self.inner_net, InnerNet):
                  return x + 10
              return x

      result = Net()(5)
      # inner_net is an InnerNet, so the `x + 10` branch is expected.
      assert result == 15
  def test_fallback_raise_error_not_class_type():
      """
      Feature: JIT Fallback
      Description: Apply ms_class to a plain function instead of a class.
      Expectation: Raise TypeError.
      """
      with pytest.raises(TypeError):
          # Decorating inside pytest.raises catches the TypeError whether ms_class
          # rejects the function immediately or only once it is called.
          @ms_class
          def func(x, y):
              return x + y

          func(1, 2)
  def test_fallback_raise_error_not_class_instance():
      """
      Feature: JIT Fallback
      Description: Instantiate an ms_class inside construct, which is expected to be rejected in graph mode.
      Expectation: Raise ValueError.
      """
      @ms_class
      class InnerNet:
          def __init__(self):
              self.number = 1

      class Net(nn.Cell):
          def construct(self):
              out = InnerNet().number
              return out

      # Building the Cell must succeed; only running construct is expected to fail.
      net = Net()
      with pytest.raises(ValueError):
          net()
  def test_fallback_raise_error_decorate_cell():
      """
      Feature: JIT Fallback
      Description: Apply ms_class to an nn.Cell subclass, which is not allowed.
      Expectation: Raise TypeError.
      """
      # The class statement is inside pytest.raises. If ms_class rejects Cell
      # subclasses at decoration time, the TypeError would otherwise escape the
      # test as an error instead of satisfying the expectation.
      with pytest.raises(TypeError):
          @ms_class
          class Net(nn.Cell):
              def construct(self, x):
                  return x

          x = Tensor(1)
          net = Net()
          net(x)