@@ -848,23 +848,23 @@ ValuePtr CheckAxes(const AbstractBasePtr &axes_abs, const bool &is_in_axes = fal
      ValueSequencePtr in_axes_seq = dyn_cast<ValueSequence>(axes_value);
      int in_axes_size = SizeToInt(in_axes_seq->size());
      if (nparam != in_axes_size) {
-       MS_LOG(EXCEPTION) << "When vmap`s `" << axes_name
-                         << "` is a tuple or list, and its size must be equal to the number of arguments of `fn`: "
+       MS_LOG(EXCEPTION) << "When vmap`s '" << axes_name
+                         << "' is a tuple or list, and its size must be equal to the number of arguments of 'fn': "
                          << nparam << ", but got size: " << in_axes_size << ".";
      }
    }
    bool elem_all_none = IsAxesAllNone(axes_value);
    if (elem_all_none) {
-     MS_LOG(EXCEPTION) << "The `" << axes_name << "` of `vmap` cannot be all None, but got " << axes_value->ToString()
+     MS_LOG(EXCEPTION) << "The '" << axes_name << "' of 'vmap' cannot be all None, but got " << axes_value->ToString()
                        << ".";
    }
  } else {
    axes_value = axes_abs->BuildValue();
    MS_EXCEPTION_IF_NULL(axes_value);
    if (axes_value->isa<None>()) {
-     MS_LOG(EXCEPTION) << "The `" << axes_name << "` of `vmap` cannot be a single None.";
+     MS_LOG(EXCEPTION) << "The '" << axes_name << "' of 'vmap' cannot be a single None.";
    } else if (!axes_value->isa<Int64Imm>()) {
-     MS_LOG(EXCEPTION) << "The axis in vmap`s `" << axes_name << "` can only be of type Int or None, but got "
+     MS_LOG(EXCEPTION) << "The axis in vmap`s '" << axes_name << "' can only be of type Int or None, but got "
                        << axes_abs->ToString() << ".";
    }
  }
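In user-facing terms, the two checks above fire when the `in_axes`/`out_axes` container does not line up with `fn`. A minimal sketch of both failure modes, assuming the Python `vmap` front end exercised by the tests below; both calls should raise the exceptions added in this hunk when the vmapped function is traced:

```python
import numpy as np
from mindspore import Tensor
from mindspore.ops.functional import vmap

def fn(x, y):
    return x + y

x = Tensor(np.ones((2, 3), np.float32))
y = Tensor(np.ones((2, 3), np.float32))

# `in_axes` has 3 entries but `fn` takes 2 arguments: trips the size check above.
vmap(fn, in_axes=(0, 0, 0), out_axes=0)(x, y)

# Every entry is None: trips the "cannot be all None" check above.
vmap(fn, in_axes=(None, None), out_axes=0)(x, y)
```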
@@ -892,7 +892,7 @@ FuncGraphPtr VmapOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
  auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
  if (real_fn == nullptr) {
-   MS_LOG(EXCEPTION) << "'VmapOperation' arg0 " << fn->ToString() << " cast to `FuncGraphAbstractClosure` failed.";
+   MS_LOG(EXCEPTION) << "'VmapOperation' arg0 " << fn->ToString() << " cast to 'FuncGraphAbstractClosure' failed.";
  }
  FuncGraphPtr orig_graph = real_fn->func_graph();
@@ -623,7 +623,7 @@ AbstractBasePtr ReduceDim(int *axis, const AbstractBasePtr &orig_abs, int *axis_
  ShapeVector orig_shape = dyn_cast<abstract::Shape>(orig_abs->BuildShape())->shape();
  int shape_len = SizeToInt(orig_shape.size());
  if (*axis < -shape_len || *axis >= shape_len) {
-   MS_LOG(EXCEPTION) << "ValueError: The axis: " << *axis << " in `in_axes` is out of bounds for array of dimension ["
+   MS_LOG(EXCEPTION) << "ValueError: The axis: " << *axis << " in 'in_axes' is out of bounds for array of dimension ["
                      << -shape_len << "," << shape_len << ").";
  }
  *axis = *axis < 0 ? shape_len + *axis : *axis;
@@ -631,7 +631,7 @@ AbstractBasePtr ReduceDim(int *axis, const AbstractBasePtr &orig_abs, int *axis_
  if (*axis_size == -1) {
    *axis_size = LongToInt(temp_axes_size);
  } else if (*axis_size != temp_axes_size) {
-   MS_LOG(EXCEPTION) << "The `axes_size` of each argument in the scope of `vmap` should be equal, but got "
+   MS_LOG(EXCEPTION) << "The 'axis_size' of each argument in the scope of 'vmap' should be equal, but got "
                      << *axis_size << " and " << temp_axes_size << ".";
  }
  (void)orig_shape.erase(orig_shape.begin() + *axis);
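For reference, a pure-Python sketch of what `ReduceDim` computes per argument: validate and normalize the mapped axis, check that all mapped arguments share one batch size, and drop the mapped dimension from the logical shape. `reduce_dim` is a hypothetical stand-in, not part of the patch:

```python
def reduce_dim(shape, axis, axis_size=None):
    # Mirror of the bounds check and normalization in the C++ ReduceDim.
    ndim = len(shape)
    if axis < -ndim or axis >= ndim:
        raise ValueError("The axis: %d in 'in_axes' is out of bounds "
                         "for array of dimension [%d,%d)." % (axis, -ndim, ndim))
    axis = axis + ndim if axis < 0 else axis
    # All mapped arguments must share one batch size (the 'axis_size' check).
    if axis_size is not None and axis_size != shape[axis]:
        raise ValueError("The 'axis_size' of each argument in the scope of 'vmap' "
                         "should be equal, but got %d and %d." % (axis_size, shape[axis]))
    return shape[:axis] + shape[axis + 1:], shape[axis]

# reduce_dim((2, 3), -1) -> ((2,), 3); a second argument mapped on an axis of
# size 2 would then fail the consistency check with axis_size=3.
```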
@@ -666,15 +666,15 @@ AbstractBasePtr GetLogicalViewAbs(const AbstractBasePtr &physical_view_abs, cons
    return std::make_shared<AbstractTuple>(logical_view_abs_list);
  }
  ValuePtr in_axis = in_axes;
- if (!in_axis->isa<Int64Imm>() && !in_axis->isa<None>()) {
-   MS_LOG(EXCEPTION) << "The axis in vmap's `in_axes` should be a None or a scalar of type Int64Imm, but got a "
-                     << in_axis->ToString() << ".";
- }
  if (in_axis->isa<Int64Imm>()) {
    int axis = dyn_cast<Int64Imm>(in_axis)->value();
    auto logical_view_abs = ReduceDim(&axis, physical_view_abs, axis_size);
    return logical_view_abs;
  }
+ if (!in_axis->isa<None>()) {
+   MS_LOG(EXCEPTION) << "The axis in vmap's 'in_axes' should be a None or a scalar of type Int64Imm, but got a "
+                     << in_axis->ToString() << ".";
+ }
+ // in_axis is None.
  return physical_view_abs;
}
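The restructured control flow keeps the semantics: an integer entry in `in_axes` removes that dimension from the argument's logical view, while `None` now falls through to the final return instead of being screened up front. At the user level (a sketch of the assumed behavior, consistent with the tests below):

```python
import numpy as np
from mindspore import Tensor
from mindspore.ops.functional import vmap

def scale(x, w):
    return x * w

x = Tensor(np.ones((4, 3), np.float32))  # mapped along axis 0
w = Tensor(np.ones((3,), np.float32))    # in_axes entry None: shared, unbatched

# Inside `scale`, x has logical shape (3,) while w keeps shape (3,).
out = vmap(scale, in_axes=(0, None), out_axes=0)(x, w)  # out.shape == (4, 3)
```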
@@ -688,7 +688,7 @@ AbstractBasePtr ExtendDim(int *axis, const AbstractBasePtr &orig_abs, int axis_s
  }
  int shape_len = SizeToInt(orig_shape.size() + 1);
  if (*axis < -shape_len || *axis >= shape_len) {
-   MS_LOG(EXCEPTION) << "ValueError: The axis: " << *axis << " in `out_axes` is out of bounds for array of dimension ["
+   MS_LOG(EXCEPTION) << "ValueError: The axis: " << *axis << " in 'out_axes' is out of bounds for array of dimension ["
                      << -shape_len << "," << shape_len << ").";
  }
  *axis = *axis < 0 ? shape_len + *axis : *axis;
@@ -700,7 +700,7 @@ AbstractBasePtr ExtendDim(int *axis, const AbstractBasePtr &orig_abs, int axis_s
  } else if (orig_abs->isa<AbstractScalar>()) {
    out_abs = std::make_shared<abstract::AbstractTensor>(orig_abs, new_shape);
  } else {
-   MS_LOG(EXCEPTION) << "The outputs of vmap's `fn` should be consisting of tensors or constants, but got "
+   MS_LOG(EXCEPTION) << "The outputs of vmap's 'fn' should be consisting of tensors or constants, but got "
                      << orig_abs->ToString() << ".";
  }
  return out_abs;
@@ -715,7 +715,7 @@ AbstractBasePtr GetPhysicalViewAbs(const AbstractBasePtr &logical_view_abs, cons
  auto out_axes_seq = dyn_cast<ValueSequeue>(out_axes);
  if (out_axes_seq != nullptr) {
    if (logical_view_abs_list.size() != out_axes_seq->size()) {
-     MS_LOG(EXCEPTION) << "The size of vmap's `out_axes` should be equal to the number of results of `fn`: "
+     MS_LOG(EXCEPTION) << "The size of vmap's 'out_axes' should be equal to the number of results of 'fn': "
                        << logical_view_abs_list.size() << ", but got size: " << out_axes_seq->size() << ".";
    }
  }
@@ -737,7 +737,7 @@ AbstractBasePtr GetPhysicalViewAbs(const AbstractBasePtr &logical_view_abs, cons
    } else if (sub_out_axes->isa<None>()) {
      return arg_spec;
    }
-   MS_LOG(EXCEPTION) << "The axis in vmap's `out_axes` should be a None or a scalar of type Int64Imm, but got a "
+   MS_LOG(EXCEPTION) << "The axis in vmap's 'out_axes' should be a None or a scalar of type Int64Imm, but got a "
                      << sub_out_axes->ToString() << ".";
  });
  if (logical_view_abs->isa<AbstractList>()) {
@@ -746,18 +746,24 @@ AbstractBasePtr GetPhysicalViewAbs(const AbstractBasePtr &logical_view_abs, cons
    return std::make_shared<AbstractTuple>(physical_view_abs_list);
  }
- int axis = 0;
  if (out_axes->isa<None>()) {
    return logical_view_abs;
- } else if (out_axes->isa<ValueSequeue>()) {
-   ValueSequeuePtr out_axes_seq = dyn_cast<ValueSequeue>(out_axes);
+ }
+ // for the single output case, outputs: A, and out_axes: 1 or (1,).
+ ValuePtr sub_out_axes = out_axes;
+ ValueSequeuePtr out_axes_seq = dyn_cast<ValueSequeue>(out_axes);
+ if (out_axes_seq != nullptr) {
    if (out_axes_seq->size() != 1) {
-     MS_LOG(EXCEPTION) << "The size of vmap's `out_axes` should be equal to the result size: 1, but got size: "
+     MS_LOG(EXCEPTION) << "The size of vmap's 'out_axes' should be equal to the result size: 1, but got size: "
                        << out_axes_seq->size() << ".";
    }
-   axis = dyn_cast<Int64Imm>((*out_axes_seq)[0])->value();
- } else if (out_axes->isa<Int64Imm>()) {
-   axis = dyn_cast<Int64Imm>(out_axes)->value();
+   sub_out_axes = (*out_axes_seq)[0];
  }
+ int axis = 0;
+ auto axis_int_ptr = dyn_cast<Int64Imm>(sub_out_axes);
+ if (axis_int_ptr != nullptr) {
+   axis = LongToInt(axis_int_ptr->value());
+ } else {
+   MS_LOG(EXCEPTION) << "The axis in vmap's 'out_axes' should be a None or a scalar of type Int64Imm, but got a "
+                     << sub_out_axes->ToString() << ".";
+ }
  return ExtendDim(&axis, logical_view_abs, axis_size);
}
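The rewrite first normalizes the single-output case, so `out_axes` given as `1` and as `(1,)` take the same `sub_out_axes` path, and it now rejects an axis that is neither `None` nor `Int64Imm` instead of silently leaving `axis` at 0 as the old `else if` chain did. A sketch of the assumed user-level equivalence:

```python
import numpy as np
from mindspore import Tensor
from mindspore.ops.functional import vmap

def double(x):
    return 2 * x

x = Tensor(np.ones((2, 3), np.float32))
# For a single result, these two spellings of `out_axes` behave identically:
out_a = vmap(double, in_axes=0, out_axes=1)(x)     # shape (3, 2)
out_b = vmap(double, in_axes=0, out_axes=(1,))(x)  # shape (3, 2)
```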
@@ -139,7 +139,7 @@ void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) {
}
py::function PrimitivePy::GetVmapRuleFunction(const bool is_side_effect, int axis_size) {
- static const char *const get_vmap_rule_func_name = "get_vmap_rule";
+ constexpr char get_vmap_rule_func_name[] = "get_vmap_rule";
  if (py::hasattr(python_obj_, get_vmap_rule_func_name)) {
    py::function fn = python_obj_.attr(get_vmap_rule_func_name)().cast<py::function>();
    return fn;
@@ -122,20 +122,19 @@ def vmap_general_rule(prim, axis_size):
    vals_in_tuple = ()
    for val_in in args:
        val, dim = val_in
-       if isinstance(val, Tensor):
-           # Handle case such as args:(..., (A, 0), (B, 1), ...)
-           if dim is None:
-               val = _broadcast_by_axis(val, 0, axis_size)
-               dim = 0
-           out = P.Unstack(dim)(val)
+       out = ()
+       if dim is None:
+           # Handle case such as args:(..., (A, None), (1, None), ...)
+           for _ in range(axis_size):
+               out = out + (val,)
        else:
-           # Handle scalar case such as args:(..., (1, None), ...)
-           if dim is not None:
+           if isinstance(val, Tensor):
+               # Handle case such as args:(..., (A, 0), (B, 1), ...)
+               out = P.Unstack(dim)(val)
+           else:
                _raise_value_error("A variable of type other than `Tensor` is accepted, "
                                   "but the source axis is not `None`")
-           out = ()
-           for _ in range(axis_size):
-               out = out + (val,)
        vals_in_tuple = vals_in_tuple + (out,)
    if wrapped_tuple:
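The reshuffled branches make `dim is None` the primary split: replicate the value `axis_size` times when there is no source axis, unstack a `Tensor` along `dim`, and reject a non-`Tensor` that claims a source axis (previously such a value was silently broadcast and unstacked). A condensed stand-in that uses plain Python lists in place of tensors (illustrative only; the real rule relies on `P.Unstack` and re-stacks the per-slice outputs):

```python
def general_vmap_rule(prim, axis_size):
    def rule(*args):
        slices_per_arg = []
        for val, dim in args:
            if dim is None:
                slices_per_arg.append([val] * axis_size)  # replicate unbatched input
            elif isinstance(val, list):                   # list stands in for Tensor
                slices_per_arg.append(list(val))          # "unstack" along the axis
            else:
                raise ValueError("non-Tensor input with a non-None source axis")
        # Apply the primitive slice by slice; the result is "stacked" on axis 0.
        return [prim(*s) for s in zip(*slices_per_arg)], 0
    return rule

rule = general_vmap_rule(lambda a, b: a + b, axis_size=3)
assert rule(([1, 2, 3], 0), (10, None)) == ([11, 12, 13], 0)
```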
@@ -0,0 +1,354 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test vmap in graph mode"""
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.numpy as mnp
import mindspore.context as context
import mindspore.ops.operations as P
import mindspore.ops.functional as F
from mindspore import dtype as mstype
from mindspore.common import Tensor
from mindspore.ops.functional import vmap
from mindspore.common.parameter import Parameter

context.set_context(mode=context.GRAPH_MODE)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vmap_cond():
    """
    Feature: vmap
    Description: This case mainly tests the following `vmap` application scenarios in graph mode:
        1. The `fn` is a `Cell` that contains control-flow operators, such as `if` and `while`.
        2. The specific VmapRule of the `Switch` and `Add` operations.
        3. The `in_axes` is a single integer, which is automatically matched to all arguments.
    Expectation: success
    """
    class CondNet(nn.Cell):
        def __init__(self):
            super(CondNet, self).__init__()
            self.inner_tensor_a = Tensor(2, mstype.int32)
            self.inner_tensor_b = Tensor(5, mstype.int32)

        def construct(self, x, y):
            a = self.inner_tensor_a + 1
            b = self.inner_tensor_b
            if a < b:
                b += a
            else:
                b -= a
            b += 5
            i = 0
            while i < 4:
                x += 1
                i += 1
            out = b + x + y
            return out

    x_hat = Tensor([2, 3, 1], mstype.int32)
    y_hat = Tensor([5, 4, 3], mstype.int32)
    result = vmap(CondNet(), 0, 0)(x_hat, y_hat)
    expect_result = Tensor([24, 24, 21], mstype.int32)
    assert np.allclose(result.asnumpy(), expect_result.asnumpy())
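For reference, the expected values can be reproduced by hand: `a = 3`, the `if` branch gives `b = 5 + 3 + 5 = 13`, and the `while` loop adds 4 to `x`, so each batch element evaluates `13 + (x + 4) + y`:

```python
for x, y, expect in zip((2, 3, 1), (5, 4, 3), (24, 24, 21)):
    assert 13 + (x + 4) + y == expect
```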


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vmap_gradient():
    """
    Feature: vmap
    Description: This case mainly tests the following `vmap` application scenarios in graph mode:
        1. `vmap` and `grad` are used in combination.
        2. `vmap` and `jvp` are used in combination.
    Expectation: success
    """
    def forward_fn(x, y):
        out = x + 2 * y
        out = F.sin(out)
        return F.reduce_sum(out)

    class GradNet(nn.Cell):
        def __init__(self, fn):
            super(GradNet, self).__init__()
            self.fn = fn

        def construct(self, x, y):
            out = F.grad(self.fn, grad_position=(0, 1))(x, y)
            return out

    def vmap_fn(x, y):
        output = vmap(forward_fn, 1, 0)(x, y)
        return F.reduce_sum(output)

    def jvp_fn(x, y, v):
        out = F.jvp(forward_fn, (x, y), (v, v))
        return out

    x_hat = Tensor([[1., 2., 3.], [2., 3., 4.]], mstype.float32)
    y_hat = Tensor([[2., 3., 4.], [3., 4., 5.]], mstype.float32)
    expect_x_grad = Tensor([[0.28366217, -0.14550003, 0.0044257],
                            [-0.14550003, 0.0044257, 0.13673723]], mstype.float32)
    expect_y_grad = Tensor([[0.56732434, -0.29100007, 0.0088514],
                            [-0.29100007, 0.0088514, 0.27347445]], mstype.float32)
    vmap_grad_x, vmap_grad_y = vmap(GradNet(forward_fn), 1, 1)(x_hat, y_hat)
    assert np.allclose(vmap_grad_x.asnumpy(), expect_x_grad.asnumpy(), 0.0001, 0.0001)
    assert np.allclose(vmap_grad_y.asnumpy(), expect_y_grad.asnumpy(), 0.0001, 0.0001)
    grad_vmap_x, grad_vmap_y = GradNet(vmap_fn)(x_hat, y_hat)
    assert np.allclose(grad_vmap_x.asnumpy(), expect_x_grad.asnumpy(), 0.0001, 0.0001)
    assert np.allclose(grad_vmap_y.asnumpy(), expect_y_grad.asnumpy(), 0.0001, 0.0001)
    x_hat = Tensor(np.array([[1.], [2.], [3.]]), mstype.float32)
    y_hat = Tensor(np.array([[1.], [2.], [3.]]), mstype.float32)
    v_hat = Tensor(np.array([[1.], [2.], [3.]]), mstype.float32)
    vmap_jvp_x, vmap_jvp_y = vmap(jvp_fn, 0, 0)(x_hat, y_hat, v_hat)
    expect_x_jvp = Tensor([0.141120002, -0.279415488, 0.412118465], mstype.float32)
    expect_y_jvp = Tensor([-2.96997738, 5.76102161, -8.20017242], mstype.float32)
    assert np.allclose(vmap_jvp_x.asnumpy(), expect_x_jvp.asnumpy(), 0.0001, 0.0001)
    assert np.allclose(vmap_jvp_y.asnumpy(), expect_y_jvp.asnumpy(), 0.0001, 0.0001)
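Since `forward_fn` computes `sum(sin(x + 2y))`, the analytic gradients are `cos(x + 2y)` with respect to `x` and `2 * cos(x + 2y)` with respect to `y`; a standalone NumPy check reproduces the expected tensors used above:

```python
import numpy as np

x = np.array([[1., 2., 3.], [2., 3., 4.]], np.float32)
y = np.array([[2., 3., 4.], [3., 4., 5.]], np.float32)
assert np.allclose(np.cos(x + 2 * y),
                   [[0.28366217, -0.14550003, 0.0044257],
                    [-0.14550003, 0.0044257, 0.13673723]], 0.0001, 0.0001)
assert np.allclose(2 * np.cos(x + 2 * y),
                   [[0.56732434, -0.29100007, 0.0088514],
                    [-0.29100007, 0.0088514, 0.27347445]], 0.0001, 0.0001)
```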


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_vmap_monad():
    """
    Feature: vmap
    Description: This case mainly tests the following `vmap` application scenarios in graph mode:
        1. The `fn` is a `Cell` that contains side-effect operators, such as `AssignAdd`, `Assign`,
           `Print` and `ScatterAdd`.
        2. A `Parameter` as an argument.
    Expectation: success
    """
    class AssignNet(nn.Cell):
        def __init__(self):
            super(AssignNet, self).__init__()
            self.assign = P.Assign()
            self.assign_add = P.AssignAdd()
            self.scatter_add = P.ScatterAdd()
            self.assign_ref = Parameter(Tensor([[0, 0, 0], [1, 1, 1]], mstype.float32), name='assign_ref')
            self.replace_tensor = Tensor([[1, 1, 1], [2, 2, 2]], mstype.float32)

        def construct(self, assign_add_val, assign_add_var, scatter_ref, indices, updates):
            self.assign(self.assign_ref, self.replace_tensor)
            F.print(self.assign_ref)
            out = self.assign_add(assign_add_var, assign_add_val) + self.scatter_add(scatter_ref, indices, updates)
            return out

    class VmapMonadNet(nn.Cell):
        def __init__(self, net):
            super(VmapMonadNet, self).__init__()
            self.net = net
            self.assign_add_var = Parameter(
                Tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[2, 2, 2], [2, 2, 2], [2, 2, 2]]], mstype.float32),
                name='assign_add_var')
            self.scatter_ref = Parameter(
                Tensor([[[0, 0, 0], [0, 0, 0]], [[1, 1, 1], [1, 1, 1]], [[2, 2, 2], [2, 2, 2]]], mstype.float32),
                name='scatter_ref')

        def construct(self, assign_add_val, scatter_indices, scatter_updates):
            output = vmap(self.net, (0, 1, 0, 0, None), 1)(assign_add_val, self.assign_add_var,
                                                           self.scatter_ref, scatter_indices, scatter_updates)
            return output, self.assign_add_var

    assign_add_val = Tensor([[[1, 1, 1], [2, 2, 2]], [[1, 1, 1], [2, 2, 2]], [[1, 1, 1], [2, 2, 2]]], mstype.float32)
    scatter_indices = Tensor([[[0, 1], [1, 1]], [[0, 1], [0, 1]], [[1, 1], [1, 0]]], mstype.int32)
    scatter_updates = Tensor([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]], mstype.int32)
    output, assign_add_var = VmapMonadNet(AssignNet())(assign_add_val, scatter_indices, scatter_updates)
    expect_output = Tensor([[[3, 3, 3], [7, 7, 7], [8, 8, 8]], [[13, 13, 13], [11, 11, 11], [12, 12, 12]]],
                           mstype.float32)
    expect_assign_add_var = Tensor([[[2, 2, 2], [2, 2, 2], [2, 2, 2]], [[4, 4, 4], [4, 4, 4], [4, 4, 4]]],
                                   mstype.float32)
    assert np.allclose(output.asnumpy(), expect_output.asnumpy())
    assert np.allclose(assign_add_var.asnumpy(), expect_assign_add_var.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vmap_reduce():
    """
    Feature: vmap
    Description: This case mainly tests the following `vmap` application scenarios in graph mode:
        1. The specific VmapRule of the `ReduceSum` operation.
        2. The `out_axes` is a single integer, which is automatically matched to all outputs.
    Expectation: success
    """
    class ReduceNet(nn.Cell):
        def __init__(self):
            super(ReduceNet, self).__init__()
            self.reduce_sum = P.ReduceSum(keep_dims=False)
            self.reduce_sum_keep_dims = P.ReduceSum(keep_dims=True)

        def construct(self, x):
            out1 = self.reduce_sum(x)
            out2 = self.reduce_sum_keep_dims(x)
            out3 = self.reduce_sum(x, 1)
            out4 = self.reduce_sum_keep_dims(x, 1)
            out5 = self.reduce_sum(x, (0, 1))
            out6 = self.reduce_sum_keep_dims(x, (0, 1))
            output = (out1, out2, out3, out4, out5, out6)
            return output

    class VmapNet(nn.Cell):
        def __init__(self, net):
            super(VmapNet, self).__init__()
            self.net = net

        def construct(self, x):
            vmap_function = F.vmap(self.net, 1, 0)
            output = vmap_function(x)
            return output

    x_hat = Tensor(np.array([[[[1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3]],
                              [[4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]],
                              [[7, 7, 7, 7, 7, 7], [8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9]]],
                             [[[1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3]],
                              [[4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]],
                              [[7, 7, 7, 7, 7, 7], [8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9]]],
                             [[[1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3]],
                              [[4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]],
                              [[7, 7, 7, 7, 7, 7], [8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9]]]]), mstype.float32)
    result1, result2, result3, result4, result5, result6 = VmapNet(ReduceNet())(x_hat)
    expect_result1 = Tensor([108, 270, 432], mstype.float32)
    assert np.allclose(result1.asnumpy(), expect_result1.asnumpy())
    expect_result2 = Tensor([[[[108]]], [[[270]]], [[[432]]]], mstype.float32)
    assert np.allclose(result2.asnumpy(), expect_result2.asnumpy())
    expect_result3 = Tensor([[[6, 6, 6, 6, 6, 6], [6, 6, 6, 6, 6, 6], [6, 6, 6, 6, 6, 6]],
                             [[15, 15, 15, 15, 15, 15], [15, 15, 15, 15, 15, 15], [15, 15, 15, 15, 15, 15]],
                             [[24, 24, 24, 24, 24, 24], [24, 24, 24, 24, 24, 24], [24, 24, 24, 24, 24, 24]]],
                            mstype.float32)
    assert np.allclose(result3.asnumpy(), expect_result3.asnumpy())
    expect_result4 = Tensor([[[[6, 6, 6, 6, 6, 6]], [[6, 6, 6, 6, 6, 6]], [[6, 6, 6, 6, 6, 6]]],
                             [[[15, 15, 15, 15, 15, 15]], [[15, 15, 15, 15, 15, 15]], [[15, 15, 15, 15, 15, 15]]],
                             [[[24, 24, 24, 24, 24, 24]], [[24, 24, 24, 24, 24, 24]], [[24, 24, 24, 24, 24, 24]]]],
                            mstype.float32)
    assert np.allclose(result4.asnumpy(), expect_result4.asnumpy())
    expect_result5 = Tensor([[18, 18, 18, 18, 18, 18], [45, 45, 45, 45, 45, 45], [72, 72, 72, 72, 72, 72]],
                            mstype.float32)
    assert np.allclose(result5.asnumpy(), expect_result5.asnumpy())
    expect_result6 = Tensor([[[[18, 18, 18, 18, 18, 18]]], [[[45, 45, 45, 45, 45, 45]]], [[[72, 72, 72, 72, 72, 72]]]],
                            mstype.float32)
    assert np.allclose(result6.asnumpy(), expect_result6.asnumpy())
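As a cross-check of `result1`: the mapped axis is 1, so batch `i` fully reduces the slice `x_hat[:, i]`, which holds three copies of the rows `3i + 1`, `3i + 2`, `3i + 3`, each repeated six times:

```python
assert [3 * 6 * ((3 * i + 1) + (3 * i + 2) + (3 * i + 3)) for i in range(3)] == [108, 270, 432]
```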


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vmap_general_rule():
    """
    Feature: vmap
    Description: This case mainly tests the following `vmap` application scenarios in graph mode:
        1. The general VmapRule.
        2. The specific VmapRule of the `Reshape` operation.
        3. The same `vmap` object is called multiple times.
        4. `mindspore.numpy` arrays as the arguments.
    Expectation: success
    """
    def convolve(x, w):
        output = []
        for i in range(1, len(x) - 1):
            output.append(mnp.dot(x[i - 1:i + 2], w))
        return mnp.stack(output)

    x = mnp.arange(5).astype('float32')
    w = mnp.array([1., 2., 3.])
    vmap_function = vmap(convolve)
    x1 = mnp.stack([x, x, x])
    w1 = mnp.stack([w, w, w])
    result1 = vmap_function(x1, w1)
    expect_result1 = Tensor([[8, 14, 20], [8, 14, 20], [8, 14, 20]], mstype.float32)
    assert np.allclose(result1.asnumpy(), expect_result1.asnumpy())
    x2 = mnp.stack([x, x + 1, x + 2])
    w2 = mnp.stack([w, w * 2, w * 3])
    result2 = vmap_function(x2, w2)
    expect_result2 = Tensor([[8, 14, 20], [28, 40, 52], [60, 78, 96]], mstype.float32)
    assert np.allclose(result2.asnumpy(), expect_result2.asnumpy())
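The batched call is equivalent to looping `convolve` over the leading axis; a standalone NumPy check of `expect_result2`:

```python
import numpy as np

def convolve_np(x, w):
    return np.stack([np.dot(x[i - 1:i + 2], w) for i in range(1, len(x) - 1)])

x = np.arange(5, dtype=np.float32)
w = np.array([1., 2., 3.], np.float32)
x2 = np.stack([x, x + 1, x + 2])
w2 = np.stack([w, w * 2, w * 3])
batched = np.stack([convolve_np(xb, wb) for xb, wb in zip(x2, w2)])
assert np.allclose(batched, [[8, 14, 20], [28, 40, 52], [60, 78, 96]])
```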


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vmap_nested_axes():
    """
    Feature: vmap
    Description: This case mainly tests the following `vmap` application scenarios in graph mode:
        1. Nested inputs as the `vmap` arguments.
        2. One element of the `in_axes` is a negative integer.
        3. Some outputs of the function are scalars with a non-None destination axis.
        4. The `in_axes` is a nested tuple and list.
        5. The VmapRule for operators that take a variable number of inputs, such as `Stack`.
    Expectation: success
    """
    class AddNet(nn.Cell):
        def __init__(self):
            super(AddNet, self).__init__()
            self.inner_tensor = Tensor([5, 6], mstype.float32)
            self.inner_para = Parameter(Tensor([5, 6], mstype.float32), name='inner_para')

        def construct(self, x, y):
            a = 1
            b = 2
            c = 3
            d = self.inner_tensor + a
            e = F.stack((self.inner_para, self.inner_para))
            return ((a, b), c), d, e

    x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    z_hat = 1
    ((res1, res2), res3), res4, res5 = \
        vmap(AddNet(), in_axes=(1, [-1, None]), out_axes=((0, None), 0, None))(x_hat, (y_hat, z_hat))
    expect_res1 = Tensor([1, 1, 1], mstype.float32)
    expect_res2 = Tensor([2, 2, 2], mstype.float32)
    expect_res3 = 3
    expect_res4 = Tensor([[6, 7], [6, 7], [6, 7]], mstype.float32)
    expect_res5 = Tensor([[5, 6], [5, 6]], mstype.float32)
    assert np.allclose(res1.asnumpy(), expect_res1.asnumpy())
    assert np.allclose(res2.asnumpy(), expect_res2.asnumpy())
    assert res3 == expect_res3
    assert np.allclose(res4.asnumpy(), expect_res4.asnumpy())
    assert np.allclose(res5.asnumpy(), expect_res5.asnumpy())
@@ -0,0 +1,69 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test vmap in pynative mode"""
import pytest
import numpy as np
import mindspore.context as context
import mindspore.ops.functional as F
from mindspore import dtype as mstype
from mindspore.common import Tensor
from mindspore.ops.functional import vmap
from mindspore.common.api import ms_function

context.set_context(mode=context.PYNATIVE_MODE)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_vmap_nested():
    """
    Feature: vmap
    Description: This case mainly tests the following `vmap` application scenarios in PyNative mode:
        1. Calling nested `vmap` functions.
        2. The `fn` is a function wrapped with `ms_function`.
        3. The function contains free variables.
    Expectation: success
    """
    outter_tensor = Tensor([1], mstype.float32)

    def add_fn(x):
        return F.add(x, outter_tensor)

    @ms_function
    def inner_vmap_fn(x, outter_tensor):
        vmap_funtion = vmap(add_fn, 1)
        out = vmap_funtion(x)
        output = out + outter_tensor
        return output

    def outter_vmap_fn(x):
        output = vmap(inner_vmap_fn, (0, None), 1)(x, outter_tensor)
        return output

    x_hat = Tensor([[[1., 2., 3.], [4., 5., 6.]],
                    [[2., 3., 4.], [5., 6., 7.]],
                    [[3., 4., 5.], [6., 7., 8.]],
                    [[4., 5., 6.], [7., 8., 9.]]], mstype.float32)
    result = outter_vmap_fn(x_hat)
    expect_result = Tensor([[[3., 6.], [4., 7.], [5., 8.], [6., 9.]],
                            [[4., 7.], [5., 8.], [6., 9.], [7., 10.]],
                            [[5., 8.], [6., 9.], [7., 10.], [8., 11.]]], mstype.float32)
    assert np.allclose(result.asnumpy(), expect_result.asnumpy())
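Numerically, the nested `vmap`s add 2 in total (once inside `add_fn`, once after the inner call) and permute the mapped axes to the positions requested by the `out_axes`; a standalone NumPy check of the expected result:

```python
import numpy as np

x = np.array([[[1., 2., 3.], [4., 5., 6.]],
              [[2., 3., 4.], [5., 6., 7.]],
              [[3., 4., 5.], [6., 7., 8.]],
              [[4., 5., 6.], [7., 8., 9.]]], np.float32)
expected = np.transpose(x, (2, 0, 1)) + 2  # shape (3, 4, 2)
assert np.allclose(expected, [[[3., 6.], [4., 7.], [5., 8.], [6., 9.]],
                              [[4., 7.], [5., 8.], [6., 9.], [7., 10.]],
                              [[5., 8.], [6., 9.], [7., 10.], [8., 11.]]])
```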
@@ -0,0 +1,220 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test vmap in graph mode"""
import pytest
import mindspore.nn as nn
import mindspore.context as context
import mindspore.ops.operations as P
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.ops.functional import vmap

context.set_context(mode=context.GRAPH_MODE)


class ThreeInputsTwoOutputsNet(nn.Cell):
    def construct(self, x, y, z):
        return x + y, z


def test_lambda_fn():
    """
    Feature: vmap
    Description: The first argument of `vmap` is a lambda function.
    Expectation: throws TypeError: "Parse Lambda Function Fail. Node type must be Lambda, but got Call."
    """
    x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    z_hat = 1
    with pytest.raises(TypeError) as ex:
        vmap(lambda x, y, z: x + y + z, in_axes=(1, 1, None), out_axes=0)(x_hat, y_hat, z_hat)
    assert "Parse Lambda Function Fail. Node type must be Lambda, but got Call." in str(ex.value)


def test_single_op():
    """
    Feature: vmap
    Description: The first argument of `vmap` is a single primitive.
    Expectation: throws RuntimeError: "'VmapOperation' arg0 Prim: S-Prim-Add cast to 'FuncGraphAbstractClosure'
        failed."
    """
    x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    with pytest.raises(RuntimeError) as ex:
        vmap(P.Add(), in_axes=(1, 1), out_axes=0)(x_hat, y_hat)
    assert "'VmapOperation' arg0 Prim: S-Prim-Add cast to 'FuncGraphAbstractClosure' failed." in str(ex.value)


def test_none_in_axes():
    """
    Feature: vmap
    Description: The `in_axes` argument of `vmap` is a single None, which is invalid when applying `vmap`.
    Expectation: throws RuntimeError: "The 'in_axes' of 'vmap' cannot be a single None."
    """
    x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    z_hat = 1
    with pytest.raises(RuntimeError) as ex:
        vmap(ThreeInputsTwoOutputsNet(), in_axes=None, out_axes=0)(x_hat, y_hat, z_hat)
    assert "The 'in_axes' of 'vmap' cannot be a single None." in str(ex.value)


def test_none_out_axes():
    """
    Feature: vmap
    Description: The `out_axes` argument of `vmap` is a nested structure consisting entirely of None, which is
        invalid when applying `vmap`.
    Expectation: throws RuntimeError: "The 'out_axes' of 'vmap' cannot be all None, but got
        (None, None, None, (None, None))."
    """
    x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    z_hat = 1
    with pytest.raises(RuntimeError) as ex:
        vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None),
             out_axes=(None, None, None, (None, None)))(x_hat, y_hat, z_hat)
    assert "The 'out_axes' of 'vmap' cannot be all None, but got (None, None, None, (None, None))." in str(ex.value)


def test_mismatch_out_axes():
    """
    Feature: vmap
    Description: The `out_axes` of `vmap` is set to (0, 0, 0), but `fn` returns only two outputs: `x + y, z`.
    Expectation: throws RuntimeError: "The size of vmap's 'out_axes' should be equal to the number of results
        of 'fn': 2, but got size: 3."
    """
    x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    z_hat = 1
    with pytest.raises(RuntimeError) as ex:
        vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None), out_axes=(0, 0, 0))(x_hat, y_hat, z_hat)
    assert "The size of vmap's 'out_axes' should be equal to the number of results of 'fn': 2, but got size: 3." \
        in str(ex.value)


def test_axis_type():
    """
    Feature: vmap
    Description: The `in_axes` of `vmap` contains elements of float type.
    Expectation: throws RuntimeError: "The axis in vmap's 'in_axes' should be a None or a scalar of type Int64Imm,
        but got a 1."
    """
    x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    z_hat = 1
    with pytest.raises(RuntimeError) as ex:
        vmap(ThreeInputsTwoOutputsNet(), in_axes=(1., 1., None), out_axes=0)(x_hat, y_hat, z_hat)
    assert "The axis in vmap's 'in_axes' should be a None or a scalar of type Int64Imm, but got a 1." in str(ex.value)


def test_axis_out_of_bounds():
    """
    Feature: vmap
    Description: The dimension of `x_hat` is 2, but the corresponding axis is set to -3.
    Expectation: throws RuntimeError: "The axis: -3 in 'in_axes' is out of bounds for array of dimension [-2,2)."
    """
    x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    z_hat = 1
    with pytest.raises(RuntimeError) as ex:
        vmap(ThreeInputsTwoOutputsNet(), in_axes=(-3, 2, None), out_axes=0)(x_hat, y_hat, z_hat)
    assert "The axis: -3 in 'in_axes' is out of bounds for array of dimension [-2,2)." in str(ex.value)


def test_mismatch_none_axis():
    """
    Feature: vmap
    Description: The source axis of the first output of `fn` is non-None, but the corresponding `out_axes` entry
        is None, which is invalid when applying `vmap`.
    Expectation: throws RuntimeError: "It is invalid that source is not None and dst is None."
    """
    x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    z_hat = 1
    with pytest.raises(RuntimeError) as ex:
        vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None), out_axes=(None, 0))(x_hat, y_hat, z_hat)
    assert "It is invalid that source is not None and dst is None." in str(ex.value)


def test_mismatch_parameters_number():
    """
    Feature: vmap
    Description: The cell takes the arguments (x, y, z), but the vmapped function is called with only
        (x_hat, y_hat).
    Expectation: throws TypeError: "The parameters number of the function is 3, but the number of provided
        arguments is 2."
    """
    x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    with pytest.raises(TypeError) as ex:
        vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None), out_axes=0)(x_hat, y_hat)
    assert "The parameters number of the function is 3, but the number of provided arguments is 2." in str(ex.value)


def test_mismatch_axis_size():
    """
    Feature: vmap
    Description: The `axis_size` of `x_hat` is 3 while the `axis_size` of `y_hat` is 2; `vmap` requires the
        `axis_size` of all arguments to be uniform.
    Expectation: throws RuntimeError: "The 'axis_size' of each argument in the scope of 'vmap' should be equal,
        but got 3 and 2."
    """
    x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    z_hat = 1
    with pytest.raises(RuntimeError) as ex:
        vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 0, None), out_axes=0)(x_hat, y_hat, z_hat)
    assert "The 'axis_size' of each argument in the scope of 'vmap' should be equal, but got 3 and 2." in str(ex.value)


def test_vmap_non_input():
    """
    Feature: vmap
    Description: The cell takes no arguments, which is invalid when applying `vmap`.
    Expectation: throws RuntimeError: "Failed to get 'axis_size' within the scope of vmap."
    """
    class NonInputSingleOutputNet(nn.Cell):
        def construct(self):
            return 1

    with pytest.raises(RuntimeError) as ex:
        vmap(NonInputSingleOutputNet())()
    assert "Failed to get 'axis_size' within the scope of vmap." in str(ex.value)


def test_non_fn():
    """
    Feature: vmap
    Description: The first argument of `vmap` is not provided, although it is a required positional argument.
    Expectation: throws TypeError: "vmap() missing 1 required positional argument: 'fn'"
    """
    x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    z_hat = 1
    with pytest.raises(TypeError) as ex:
        vmap(in_axes=(1, 1, None), out_axes=0)(x_hat, y_hat, z_hat)
    assert "vmap() missing 1 required positional argument: 'fn'" in str(ex.value)


def test_scalar_with_non_zero_axis():
    """
    Feature: vmap
    Description: The second output of `fn` is a scalar whose source axis is None but whose destination axis is 1,
        which is invalid when applying `vmap`.
    Expectation: throws RuntimeError: "The axis: 1 in 'out_axes' is out of bounds for array of dimension [-1,1)."
    """
    x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32)
    z_hat = 1
    with pytest.raises(RuntimeError) as ex:
        vmap(ThreeInputsTwoOutputsNet(), in_axes=(1, 1, None), out_axes=(0, 1))(x_hat, y_hat, z_hat)
    assert "The axis: 1 in 'out_axes' is out of bounds for array of dimension [-1,1)." in str(ex.value)