You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_partial.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test partial"""
  16. from functools import partial
  17. import numpy as np
  18. import pytest
  19. from mindspore import nn, Tensor, context
# Put the whole module in MindSpore GRAPH_MODE: every test below exercises
# `functools.partial` inside a Cell's `construct`, and the expected error
# messages in the multi-assign tests are MindSpore's wording, not Python's.
context.set_context(mode=context.GRAPH_MODE)

  21. def test_partial_pos_arg():
  22. """
  23. Feature: ALL TO ALL
  24. Description: test cases for partial_pos_arg
  25. Expectation: the result match given one
  26. """
  27. class Net(nn.Cell):
  28. def __init__(self):
  29. super(Net, self).__init__()
  30. def show(self, x, y, z):
  31. return x, y, z
  32. def construct(self, x, y, z):
  33. f = partial(self.show, x)
  34. ret = f(y, z)
  35. return ret
  36. class Net2(nn.Cell):
  37. def __init__(self):
  38. super(Net2, self).__init__()
  39. self.show = lambda x, y, z: (x, y, z)
  40. def construct(self, x, y, z):
  41. f = partial(self.show, x)
  42. ret = f(y, z)
  43. return ret
  44. x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
  45. y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
  46. z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
  47. for net in [Net(), Net2()]:
  48. net(x, y, z)
  49. def test_partial_key_ward_arg():
  50. """
  51. Feature: ALL TO ALL
  52. Description: test cases for partial_key_ward_arg
  53. Expectation: the result match given one
  54. """
  55. class Net(nn.Cell):
  56. def __init__(self):
  57. super(Net, self).__init__()
  58. def show(self, x, y, z):
  59. return x, y, z
  60. def construct(self, x, y, z):
  61. f = partial(self.show, x=x)
  62. ret = f(y=y, z=z)
  63. return ret
  64. class Net2(nn.Cell):
  65. def __init__(self):
  66. super(Net2, self).__init__()
  67. self.show = lambda x, y, z: (x, y, z)
  68. def construct(self, x, y, z):
  69. f = partial(self.show, x=x)
  70. ret = f(y=y, z=z)
  71. return ret
  72. x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
  73. y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
  74. z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
  75. for net in [Net(), Net2()]:
  76. net(x, y, z)
  77. def test_partial_key_ward_arg_update():
  78. """
  79. Feature: ALL TO ALL
  80. Description: test cases for partial_key_ward_arg_update
  81. Expectation: the result match given one
  82. """
  83. class Net(nn.Cell):
  84. def __init__(self):
  85. super(Net, self).__init__()
  86. def show(self, x, y, z):
  87. return x, y, z
  88. def construct(self, x, y, z):
  89. f = partial(self.show, x=x, y=y)
  90. ret = f(y=y, z=z)
  91. return ret
  92. class Net2(nn.Cell):
  93. def __init__(self):
  94. super(Net2, self).__init__()
  95. self.show = lambda x, y, z: (x, y, z)
  96. def construct(self, x, y, z):
  97. f = partial(self.show, x=x, y=y)
  98. ret = f(y=y, z=z)
  99. return ret
  100. x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
  101. y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
  102. z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
  103. for net in [Net(), Net2()]:
  104. net(x, y, z)
  105. def test_partial_key_ward_arg_and_pos_arg():
  106. """
  107. Feature: ALL TO ALL
  108. Description: test cases for partial_key_ward_arg_and_pos_arg
  109. Expectation: the result match given one
  110. """
  111. class Net(nn.Cell):
  112. def __init__(self):
  113. super(Net, self).__init__()
  114. def show(self, x, y, z):
  115. return x, y, z
  116. def construct(self, x, y, z):
  117. f = partial(self.show, y=y)
  118. ret = f(2, z=z)
  119. return ret
  120. class Net2(nn.Cell):
  121. def __init__(self):
  122. super(Net2, self).__init__()
  123. self.show = lambda x, y, z: (x, y, z)
  124. def construct(self, x, y, z):
  125. f = partial(self.show, y=y)
  126. ret = f(2, z=z)
  127. return ret
  128. x = Tensor(np.arange(3).reshape((3,)).astype(np.float32))
  129. y = Tensor(np.arange(3 * 4).reshape((3, 4)).astype(np.float32))
  130. z = Tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(np.float32))
  131. for net in [Net(), Net2()]:
  132. net(x, y, z)
  133. def test_partial_pos_arg_const():
  134. """
  135. Feature: ALL TO ALL
  136. Description: test cases for partial_pos_arg_const
  137. Expectation: the result match given one
  138. """
  139. class Net(nn.Cell):
  140. def __init__(self):
  141. super(Net, self).__init__()
  142. def show(self, x, y, z):
  143. return x, y, z
  144. def construct(self):
  145. f = partial(self.show, 1)
  146. ret = f(2, 3)
  147. return ret
  148. class Net2(nn.Cell):
  149. def __init__(self):
  150. super(Net2, self).__init__()
  151. self.show = lambda x, y, z: (x, y, z)
  152. def construct(self):
  153. f = partial(self.show, 1)
  154. ret = f(2, 3)
  155. return ret
  156. for net in [Net(), Net2()]:
  157. assert net() == (1, 2, 3)
  158. def test_partial_key_ward_arg_const():
  159. """
  160. Feature: ALL TO ALL
  161. Description: test cases for partial_key_ward_arg_const
  162. Expectation: the result match given one
  163. """
  164. class Net(nn.Cell):
  165. def __init__(self):
  166. super(Net, self).__init__()
  167. def show(self, x, y, z):
  168. return x, y, z
  169. def construct(self):
  170. f = partial(self.show, x=1)
  171. ret = f(y=2, z=3)
  172. return ret
  173. class Net2(nn.Cell):
  174. def __init__(self):
  175. super(Net2, self).__init__()
  176. self.show = lambda x, y, z: (x, y, z)
  177. def construct(self):
  178. f = partial(self.show, x=1)
  179. ret = f(y=2, z=3)
  180. return ret
  181. for net in [Net(), Net2()]:
  182. assert net() == (1, 2, 3)
  183. def test_partial_key_ward_arg_update_const():
  184. """
  185. Feature: ALL TO ALL
  186. Description: test cases for partial_key_ward_arg_update_const
  187. Expectation: the result match given one
  188. """
  189. class Net(nn.Cell):
  190. def __init__(self):
  191. super(Net, self).__init__()
  192. def show(self, x, y, z):
  193. return x, y, z
  194. def construct(self):
  195. f = partial(self.show, x=1, y=2)
  196. ret = f(y=3, z=4)
  197. return ret
  198. class Net2(nn.Cell):
  199. def __init__(self):
  200. super(Net2, self).__init__()
  201. self.show = lambda x, y, z: (x, y, z)
  202. def construct(self):
  203. f = partial(self.show, x=1, y=2)
  204. ret = f(y=3, z=4)
  205. return ret
  206. for net in [Net(), Net2()]:
  207. assert net() == (1, 3, 4)
  208. def test_partial_key_ward_arg_and_pos_arg_const():
  209. """
  210. Feature: ALL TO ALL
  211. Description: test cases for partial_key_ward_arg_and_pos_arg_const
  212. Expectation: the result match given one
  213. """
  214. class Net(nn.Cell):
  215. def __init__(self):
  216. super(Net, self).__init__()
  217. def show(self, x, y, z):
  218. return x, y, z
  219. def construct(self):
  220. f = partial(self.show, y=2)
  221. ret = f(1, z=3)
  222. return ret
  223. class Net2(nn.Cell):
  224. def __init__(self):
  225. super(Net2, self).__init__()
  226. self.show = lambda x, y, z: (x, y, z)
  227. def construct(self):
  228. f = partial(self.show, y=2)
  229. ret = f(1, z=3)
  230. return ret
  231. for net in [Net(), Net2()]:
  232. assert net() == (1, 2, 3)
  233. def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_x():
  234. """
  235. Feature: ALL TO ALL
  236. Description: test cases for partial_key_ward_arg_and_pos_arg_const_multi_assign_x
  237. Expectation: the result match given one
  238. """
  239. class Net(nn.Cell):
  240. def __init__(self):
  241. super(Net, self).__init__()
  242. def show(self, x, y, z):
  243. return x, y, z
  244. def construct(self):
  245. f = partial(self.show, x=1)
  246. ret = f(1, 2, 3)
  247. return ret
  248. class Net2(nn.Cell):
  249. def __init__(self):
  250. super(Net2, self).__init__()
  251. self.show = lambda x, y, z: (x, y, z)
  252. def construct(self):
  253. f = partial(self.show, x=1)
  254. ret = f(1, 2, 3)
  255. return ret
  256. for net in [Net(), Net2()]:
  257. with pytest.raises(TypeError) as ex:
  258. net()
  259. assert "Multiply values for specific argument: x" in str(ex.value)
  260. def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_y():
  261. """
  262. Feature: ALL TO ALL
  263. Description: test cases for partial_key_ward_arg_and_pos_arg_const_multi_assign_y
  264. Expectation: the result match given one
  265. """
  266. class Net(nn.Cell):
  267. def __init__(self):
  268. super(Net, self).__init__()
  269. def show(self, x, y, z):
  270. return x, y, z
  271. def construct(self):
  272. f = partial(self.show, y=2)
  273. ret = f(1, 2, z=3)
  274. return ret
  275. class Net2(nn.Cell):
  276. def __init__(self):
  277. super(Net2, self).__init__()
  278. self.show = lambda x, y, z: (x, y, z)
  279. def construct(self):
  280. f = partial(self.show, y=2)
  281. ret = f(1, 2, z=3)
  282. return ret
  283. for net in [Net(), Net2()]:
  284. with pytest.raises(TypeError) as ex:
  285. net()
  286. assert "Multiply values for specific argument: y" in str(ex.value)
  287. def test_partial_key_ward_arg_and_pos_arg_const_multi_assign_z():
  288. """
  289. Feature: ALL TO ALL
  290. Description: test cases for partial_key_ward_arg_and_pos_arg_const_multi_assign_z
  291. Expectation: the result match given one
  292. """
  293. class Net(nn.Cell):
  294. def __init__(self):
  295. super(Net, self).__init__()
  296. def show(self, x, y, z):
  297. return x, y, z
  298. def construct(self):
  299. f = partial(self.show, z=1)
  300. ret = f(1, 2, 3)
  301. return ret
  302. class Net2(nn.Cell):
  303. def __init__(self):
  304. super(Net2, self).__init__()
  305. self.show = lambda x, y, z: (x, y, z)
  306. def construct(self):
  307. f = partial(self.show, z=1)
  308. ret = f(1, 2, 3)
  309. return ret
  310. for net in [Net(), Net2()]:
  311. with pytest.raises(TypeError) as ex:
  312. net()
  313. assert "Multiply values for specific argument: z" in str(ex.value)