# Copyright 2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """test vmap in graph mode""" import pytest import mindspore.nn as nn import mindspore.context as context import mindspore.ops.operations as P from mindspore import Tensor from mindspore import dtype as mstype from mindspore.ops.functional import vmap context.set_context(mode=context.GRAPH_MODE) class ThreeInputsTwoOutputsNet(nn.Cell): def construct(self, x, y, z): return x + y, z def test_lambda_fn(): """ Feature: vmap Description: The first argument of `vmap` is a lambda function. Expectation: throw TypeError:"Parse Lambda Function Fail. Node type must be Lambda, but got Call." """ x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) z_hat = 1 with pytest.raises(TypeError) as ex: vmap(lambda x, y, z: x + y + z, in_axes=(1, 1, None), out_axes=0)(x_hat, y_hat, z_hat) assert "Parse Lambda Function Fail. Node type must be Lambda, but got Call." in str(ex.value) def test_single_op(): """ Feature: vmap Description: The first argument of `vmap` is a single primitive. Expectation: throw RuntimeError:"'VmapOperation' arg0 Prim: S-Prim-Add cast to 'FuncGraphAbstractClosure' failed." """ x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) with pytest.raises(RuntimeError) as ex: vmap(P.Add(), in_axes=(1, 1), out_axes=0)(x_hat, y_hat) assert "'VmapOperation' arg0 Prim: S-Prim-Add cast to 'FuncGraphAbstractClosure' failed." in str(ex.value) def test_none_in_axes(): """ Feature: vmap Description: The `in_axis` argument of `vmap` is a single None, and it's invalid when apply `vmap`. Expectation: throw RuntimeError:"The 'in_axes' of 'vmap' cannot be a single None." """ x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) z_hat = 1 with pytest.raises(RuntimeError) as ex: vmap(ThreeInputsTwoOutputsNet(), in_axes=None, out_axes=0)(x_hat, y_hat, z_hat) assert "The 'in_axes' of 'vmap' cannot be a single None." in str(ex.value) def test_none_out_axes(): """ Feature: vmap Description: The `out_axes` argument of `vmap` is a nested None, and it's invalid when apply `vmap`. Expectation: throw RuntimeError:"The 'out_axes' of 'vmap' cannot be all None, but got (None, None, None, (None, None))." """ x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) z_hat = 1 with pytest.raises(RuntimeError) as ex: vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None), out_axes=(None, None, None, (None, None)))(x_hat, y_hat, z_hat) assert "The 'out_axes' of 'vmap' cannot be all None, but got (None, None, None, (None, None))." in str(ex.value) def test_mismatch_out_axes(): """ Feature: vmap Description: The `out_axes` of `vmap` sets to (0, 0, 0), but the outputs of `fn` is x + y, z. Expectation: throw RuntimeError:"The size of vmap's 'out_axes' should be equal to the number of results of 'fn': 2, but got size: 3." """ x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) z_hat = 1 with pytest.raises(RuntimeError) as ex: vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None), out_axes=(0, 0, 0))(x_hat, y_hat, z_hat) assert "The size of vmap's 'out_axes' should be equal to the number of results of 'fn': 2, but got size: 3." \ in str(ex.value) def test_axis_type(): """ Feature: vmap Description: The `in_axes` of `vmap` contains elements of Float type. Expectation: throw RuntimeError:"The axis in vmap's 'in_axes' should be a None or a scalar of type Int64Imm, but got a 1." """ x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) z_hat = 1 with pytest.raises(RuntimeError) as ex: vmap(ThreeInputsTwoOutputsNet(), in_axes=(1., 1., None), out_axes=0)(x_hat, y_hat, z_hat) assert "The axis in vmap's 'in_axes' should be a None or a scalar of type Int64Imm, but got a 1." in str(ex.value) def test_axis_out_of_bounds(): """ Feature: vmap Description: The dimension of X is 2, but the corresponding axis -3 is set. Expectation: throw RuntimeError:"The axis: -3 in 'in_axes' is out of bounds for array of dimension [-2,2)." """ x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) z_hat = 1 with pytest.raises(RuntimeError) as ex: vmap(ThreeInputsTwoOutputsNet(), in_axes=(-3, 2, None), out_axes=0)(x_hat, y_hat, z_hat) assert "The axis: -3 in 'in_axes' is out of bounds for array of dimension [-2,2)." in str(ex.value) def test_mismatch_none_axis(): """ Feature: vmap Description: The source axis of the first output of `fn` is non-None, but the `out_axes` for that is None, it's invalid when apply `vmap`. Expectation: throw RuntimeError:"It is invalid that source is not None and dst is None." """ x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) z_hat = 1 with pytest.raises(RuntimeError) as ex: vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None), out_axes=(None, 0))(x_hat, y_hat, z_hat) assert "It is invalid that source is not None and dst is None." in str(ex.value) def test_mismatch_parameters_number(): """ Feature: vmap Description: The arguments of the cell is (x, y, z), but the arguments of vmap-ed function is (x_hat, y_hat). Expectation: throw TypeError:"The parameters number of the function is 3, but the number of provided arguments is 2." """ x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) with pytest.raises(TypeError) as ex: vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None), out_axes=0)(x_hat, y_hat) assert "The parameters number of the function is 3, but the number of provided arguments is 2." in str(ex.value) def test_mismatch_axis_size(): """ Feature: vmap Description: The `axis_size` of X is 3, and the `axis_size` of Y is 2, they are not equal, vmap needs to ensure that the `axis_size` of all parameters are uniform. Expectation: throw RuntimeError:"The 'axis_size' of each argument in the scope of 'vmap' should be equal, but got 3 and 2." """ x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) z_hat = 1 with pytest.raises(RuntimeError) as ex: vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 0, None), out_axes=0)(x_hat, y_hat, z_hat) assert "The 'axis_size' of each argument in the scope of 'vmap' should be equal, but got 3 and 2." in str(ex.value) def test_vmap_non_input(): """ Feature: vmap Description: The arguments of the cell is empty, it's invalid when apply `vmap`. Expectation: throw RuntimeError:"Failed to get 'axis_size' within the scope of vmap." """ class NonInputSingleOutputNet(nn.Cell): def construct(self): return 1 with pytest.raises(RuntimeError) as ex: vmap(NonInputSingleOutputNet())() assert "Failed to get 'axis_size' within the scope of vmap." in str(ex.value) def test_non_fn(): """ Feature: vmap Description: The first argument of `vmap` not provided, which is required positional argument. Expectation: throw TypeError:"vmap() missing 1 required positional argument: 'fn'" """ x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) z_hat = 1 with pytest.raises(TypeError) as ex: vmap(in_axes=(1, 1, None), out_axes=0)(x_hat, y_hat, z_hat) assert "vmap() missing 1 required positional argument: 'fn'" in str(ex.value) def test_scalar_with_non_zero_axis(): """ Feature: vmap Description: The second output of `fn` is a scalar with source axis None, but get a destination axis 1, and it's invalid when apply `vmap`. Expectation: throw RuntimeError:"The axis: 1 in 'out_axes' is out of bounds for array of dimension [-1,1)." """ x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) z_hat = 1 with pytest.raises(RuntimeError) as ex: vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None), out_axes=(0, 1))(x_hat, y_hat, z_hat) assert "The axis: 1 in 'out_axes' is out of bounds for array of dimension [-1,1)." in str(ex.value)