You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_vmap.py 9.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """test vmap in graph mode"""
  16. import pytest
  17. import mindspore.nn as nn
  18. import mindspore.context as context
  19. import mindspore.ops.operations as P
  20. from mindspore import Tensor
  21. from mindspore import dtype as mstype
  22. from mindspore.ops.functional import vmap
# Configure MindSpore to graph mode at import time; applies to every test in this module.
context.set_context(mode=context.GRAPH_MODE)
  24. class ThreeInputsTwoOutputsNet(nn.Cell):
  25. def construct(self, x, y, z):
  26. return x + y, z
  27. def test_lambda_fn():
  28. """
  29. Feature: vmap
  30. Description: The first argument of `vmap` is a lambda function.
  31. Expectation: throw TypeError:"Parse Lambda Function Fail. Node type must be Lambda, but got Call."
  32. """
  33. x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  34. y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  35. z_hat = 1
  36. with pytest.raises(TypeError) as ex:
  37. vmap(lambda x, y, z: x + y + z, in_axes=(1, 1, None), out_axes=0)(x_hat, y_hat, z_hat)
  38. assert "Parse Lambda Function Fail. Node type must be Lambda, but got Call." in str(ex.value)
  39. def test_single_op():
  40. """
  41. Feature: vmap
  42. Description: The first argument of `vmap` is a single primitive.
  43. Expectation: throw RuntimeError:"'VmapOperation' arg0 Prim: S-Prim-Add cast to 'FuncGraphAbstractClosure' failed."
  44. """
  45. x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  46. y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  47. with pytest.raises(RuntimeError) as ex:
  48. vmap(P.Add(), in_axes=(1, 1), out_axes=0)(x_hat, y_hat)
  49. assert "'VmapOperation' arg0 Prim: S-Prim-Add cast to 'FuncGraphAbstractClosure' failed." in str(ex.value)
  50. def test_none_in_axes():
  51. """
  52. Feature: vmap
  53. Description: The `in_axis` argument of `vmap` is a single None, and it's invalid when apply `vmap`.
  54. Expectation: throw RuntimeError:"The 'in_axes' of 'vmap' cannot be a single None."
  55. """
  56. x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  57. y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  58. z_hat = 1
  59. with pytest.raises(RuntimeError) as ex:
  60. vmap(ThreeInputsTwoOutputsNet(), in_axes=None, out_axes=0)(x_hat, y_hat, z_hat)
  61. assert "The 'in_axes' of 'vmap' cannot be a single None." in str(ex.value)
  62. def test_none_out_axes():
  63. """
  64. Feature: vmap
  65. Description: The `out_axes` argument of `vmap` is a nested None, and it's invalid when apply `vmap`.
  66. Expectation: throw RuntimeError:"The 'out_axes' of 'vmap' cannot be all None, but got
  67. (None, None, None, (None, None))."
  68. """
  69. x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  70. y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  71. z_hat = 1
  72. with pytest.raises(RuntimeError) as ex:
  73. vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None),
  74. out_axes=(None, None, None, (None, None)))(x_hat, y_hat, z_hat)
  75. assert "The 'out_axes' of 'vmap' cannot be all None, but got (None, None, None, (None, None))." in str(ex.value)
  76. def test_mismatch_out_axes():
  77. """
  78. Feature: vmap
  79. Description: The `out_axes` of `vmap` sets to (0, 0, 0), but the outputs of `fn` is x + y, z.
  80. Expectation: throw RuntimeError:"The size of vmap's 'out_axes' should be equal to the number of results of 'fn': 2,
  81. but got size: 3."
  82. """
  83. x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  84. y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  85. z_hat = 1
  86. with pytest.raises(RuntimeError) as ex:
  87. vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None), out_axes=(0, 0, 0))(x_hat, y_hat, z_hat)
  88. assert "The size of vmap's 'out_axes' should be equal to the number of results of 'fn': 2, but got size: 3." \
  89. in str(ex.value)
  90. def test_axis_type():
  91. """
  92. Feature: vmap
  93. Description: The `in_axes` of `vmap` contains elements of Float type.
  94. Expectation: throw RuntimeError:"The axis in vmap's 'in_axes' should be a None or a scalar of type Int64Imm,
  95. but got a 1."
  96. """
  97. x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  98. y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  99. z_hat = 1
  100. with pytest.raises(RuntimeError) as ex:
  101. vmap(ThreeInputsTwoOutputsNet(), in_axes=(1., 1., None), out_axes=0)(x_hat, y_hat, z_hat)
  102. assert "The axis in vmap's 'in_axes' should be a None or a scalar of type Int64Imm, but got a 1." in str(ex.value)
  103. def test_axis_out_of_bounds():
  104. """
  105. Feature: vmap
  106. Description: The dimension of X is 2, but the corresponding axis -3 is set.
  107. Expectation: throw RuntimeError:"The axis: -3 in 'in_axes' is out of bounds for array of dimension [-2,2)."
  108. """
  109. x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  110. y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  111. z_hat = 1
  112. with pytest.raises(RuntimeError) as ex:
  113. vmap(ThreeInputsTwoOutputsNet(), in_axes=(-3, 2, None), out_axes=0)(x_hat, y_hat, z_hat)
  114. assert "The axis: -3 in 'in_axes' is out of bounds for array of dimension [-2,2)." in str(ex.value)
  115. def test_mismatch_none_axis():
  116. """
  117. Feature: vmap
  118. Description: The source axis of the first output of `fn` is non-None, but the `out_axes` for that is None,
  119. it's invalid when apply `vmap`.
  120. Expectation: throw RuntimeError:"It is invalid that source is not None and dst is None."
  121. """
  122. x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  123. y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  124. z_hat = 1
  125. with pytest.raises(RuntimeError) as ex:
  126. vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None), out_axes=(None, 0))(x_hat, y_hat, z_hat)
  127. assert "It is invalid that source is not None and dst is None." in str(ex.value)
  128. def test_mismatch_parameters_number():
  129. """
  130. Feature: vmap
  131. Description: The arguments of the cell is (x, y, z), but the arguments of vmap-ed function is (x_hat, y_hat).
  132. Expectation: throw TypeError:"The parameters number of the function is 3, but the number of provided arguments
  133. is 2."
  134. """
  135. x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  136. y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  137. with pytest.raises(TypeError) as ex:
  138. vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None), out_axes=0)(x_hat, y_hat)
  139. assert "The parameters number of the function is 3, but the number of provided arguments is 2." in str(ex.value)
  140. def test_mismatch_axis_size():
  141. """
  142. Feature: vmap
  143. Description: The `axis_size` of X is 3, and the `axis_size` of Y is 2, they are not equal, vmap needs to ensure
  144. that the `axis_size` of all parameters are uniform.
  145. Expectation: throw RuntimeError:"The 'axis_size' of each argument in the scope of 'vmap' should be equal,
  146. but got 3 and 2."
  147. """
  148. x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  149. y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  150. z_hat = 1
  151. with pytest.raises(RuntimeError) as ex:
  152. vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 0, None), out_axes=0)(x_hat, y_hat, z_hat)
  153. assert "The 'axis_size' of each argument in the scope of 'vmap' should be equal, but got 3 and 2." in str(ex.value)
  154. def test_vmap_non_input():
  155. """
  156. Feature: vmap
  157. Description: The arguments of the cell is empty, it's invalid when apply `vmap`.
  158. Expectation: throw RuntimeError:"Failed to get 'axis_size' within the scope of vmap."
  159. """
  160. class NonInputSingleOutputNet(nn.Cell):
  161. def construct(self):
  162. return 1
  163. with pytest.raises(RuntimeError) as ex:
  164. vmap(NonInputSingleOutputNet())()
  165. assert "Failed to get 'axis_size' within the scope of vmap." in str(ex.value)
  166. def test_non_fn():
  167. """
  168. Feature: vmap
  169. Description: The first argument of `vmap` not provided, which is required positional argument.
  170. Expectation: throw TypeError:"vmap() missing 1 required positional argument: 'fn'"
  171. """
  172. x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  173. y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  174. z_hat = 1
  175. with pytest.raises(TypeError) as ex:
  176. vmap(in_axes=(1, 1, None), out_axes=0)(x_hat, y_hat, z_hat)
  177. assert "vmap() missing 1 required positional argument: 'fn'" in str(ex.value)
  178. def test_scalar_with_non_zero_axis():
  179. """
  180. Feature: vmap
  181. Description: The second output of `fn` is a scalar with source axis None, but get a destination axis 1, and it's
  182. invalid when apply `vmap`.
  183. Expectation: throw RuntimeError:"The axis: 1 in 'out_axes' is out of bounds for array of dimension [-1,1)."
  184. """
  185. x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  186. y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
  187. z_hat = 1
  188. with pytest.raises(RuntimeError) as ex:
  189. vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None), out_axes=(0, 1))(x_hat, y_hat, z_hat)
  190. assert "The axis: 1 in 'out_axes' is out of bounds for array of dimension [-1,1)." in str(ex.value)