# Copyright 2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """test vmap in graph mode""" import pytest import numpy as np import mindspore.nn as nn import mindspore.numpy as mnp import mindspore.context as context import mindspore.ops.operations as P import mindspore.ops.functional as F from mindspore import dtype as mstype from mindspore.common import Tensor from mindspore.ops.functional import vmap from mindspore.common.parameter import Parameter context.set_context(mode=context.GRAPH_MODE) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard def test_vmap_cond(): """ Feature: vmap Description: This case mainly tests the following `vmap` application scenarios in graph mode: 1. The `fn` is a `Cell`, which contains control flow operators, such as `if` and `while`. 2. The specific VmapRule of `Switch` and `Add` operation. 3. The `in_axes` is a single integer, which automatically match to multiple arguments. Expectation: success """ class CondNet(nn.Cell): def __init__(self): super(CondNet, self).__init__() self.inner_tensor_a = Tensor(2, mstype.int32) self.inner_tensor_b = Tensor(5, mstype.int32) def construct(self, x, y): a = self.inner_tensor_a + 1 b = self.inner_tensor_b if a < b: b += a else: b -= a b += 5 i = 0 while i < 4: x += 1 i += 1 out = b + x + y return out x_hat = Tensor([2, 3, 1], mstype.int32) y_hat = Tensor([5, 4, 3], mstype.int32) result = vmap(CondNet(), 0, 0)(x_hat, y_hat) expect_result = Tensor([24, 24, 21], mstype.int32) assert np.allclose(result.asnumpy(), expect_result.asnumpy()) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard def test_vmap_gradient(): """ Feature: vmap Description: This case mainly tests the following `vmap` application scenarios in graph mode: 1. `vmap` and `grad` are used in combination. 2. `vmap` and `jvp` are used in combination. Expectation: success """ def forward_fn(x, y): out = x + 2 * y out = F.sin(out) return F.reduce_sum(out) class GradNet(nn.Cell): def __init__(self, fn): super(GradNet, self).__init__() self.fn = fn def construct(self, x, y): out = F.grad(self.fn, grad_position=(0, 1))(x, y) return out def vmap_fn(x, y): output = vmap(forward_fn, 1, 0)(x, y) return F.reduce_sum(output) def jvp_fn(x, y, v): out = F.jvp(forward_fn, (x, y), (v, v)) return out x_hat = Tensor([[1., 2., 3.], [2., 3., 4.]], mstype.float32) y_hat = Tensor([[2., 3., 4.], [3., 4., 5.]], mstype.float32) expect_x_grad = Tensor([[0.28366217, -0.14550003, 0.0044257], [-0.14550003, 0.0044257, 0.13673723]], mstype.float32) expect_y_grad = Tensor([[0.56732434, -0.29100007, 0.0088514], [-0.29100007, 0.0088514, 0.27347445]], mstype.float32) vmap_grad_x, vmap_grad_y = vmap(GradNet(forward_fn), 1, 1)(x_hat, y_hat) assert np.allclose(vmap_grad_x.asnumpy(), expect_x_grad.asnumpy(), 0.0001, 0.0001) assert np.allclose(vmap_grad_y.asnumpy(), expect_y_grad.asnumpy(), 0.0001, 0.0001) grad_vmap_x, grad_vmap_y = GradNet(vmap_fn)(x_hat, y_hat) assert np.allclose(grad_vmap_x.asnumpy(), expect_x_grad.asnumpy(), 0.0001, 0.0001) assert np.allclose(grad_vmap_y.asnumpy(), expect_y_grad.asnumpy(), 0.0001, 0.0001) x_hat = Tensor(np.array([[1.], [2.], [3.]]), mstype.float32) y_hat = Tensor(np.array([[1.], [2.], [3.]]), mstype.float32) v_hat = Tensor(np.array([[1.], [2.], [3.]]), mstype.float32) vmap_jvp_x, vmap_jvp_y = vmap(jvp_fn, 0, 0)(x_hat, y_hat, v_hat) expect_x_jvp = Tensor([0.141120002, -0.279415488, 0.412118465], mstype.float32) expect_y_jvp = Tensor([-2.96997738, 5.76102161, -8.20017242], mstype.float32) assert np.allclose(vmap_jvp_x.asnumpy(), expect_x_jvp.asnumpy(), 0.0001, 0.0001) assert np.allclose(vmap_jvp_y.asnumpy(), expect_y_jvp.asnumpy(), 0.0001, 0.0001) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_vmap_monad(): """ Feature: vmap Description: This case mainly tests the following `vmap` application scenarios in graph mode: 1. The `fn` is a `Cell`, which contains side effect operators, such as `AssignAdd`, `Assign`, `Print`, `ScatterAdd`. 2. Parameter as argument. Expectation: success """ class AssignNet(nn.Cell): def __init__(self): super(AssignNet, self).__init__() self.assign = P.Assign() self.assign_add = P.AssignAdd() self.scatter_add = P.ScatterAdd() self.assign_ref = Parameter(Tensor([[0, 0, 0], [1, 1, 1]], mstype.float32), name='assign_ref') self.replace_tensor = Tensor([[1, 1, 1], [2, 2, 2]], mstype.float32) def construct(self, assign_add_val, assign_add_var, scatter_ref, indices, updates): self.assign(self.assign_ref, self.replace_tensor) F.print(self.assign_ref) out = self.assign_add(assign_add_var, assign_add_val) + self.scatter_add(scatter_ref, indices, updates) return out class VmapMonadNet(nn.Cell): def __init__(self, net): super(VmapMonadNet, self).__init__() self.net = net self.assign_add_var = Parameter( Tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[2, 2, 2], [2, 2, 2], [2, 2, 2]]], mstype.float32), name='assign_add_var') self.scatter_ref = Parameter( Tensor([[[0, 0, 0], [0, 0, 0]], [[1, 1, 1], [1, 1, 1]], [[2, 2, 2], [2, 2, 2]]], mstype.float32), name='scatter_ref') def construct(self, assign_add_val, scatter_indices, scatter_updates): output = vmap(self.net, (0, 1, 0, 0, None), 1)(assign_add_val, self.assign_add_var, self.scatter_ref, scatter_indices, scatter_updates) return output, self.assign_add_var assign_add_val = Tensor([[[1, 1, 1], [2, 2, 2]], [[1, 1, 1], [2, 2, 2]], [[1, 1, 1], [2, 2, 2]]], mstype.float32) scatter_indices = Tensor([[[0, 1], [1, 1]], [[0, 1], [0, 1]], [[1, 1], [1, 0]]], mstype.int32) scatter_updates = Tensor([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]], mstype.int32) output, assign_add_var = VmapMonadNet(AssignNet())(assign_add_val, scatter_indices, scatter_updates) expect_output = Tensor([[[3, 3, 3], [7, 7, 7], [8, 8, 8]], [[13, 13, 13], [11, 11, 11], [12, 12, 12]]], mstype.float32) expect_assign_add_var = Tensor([[[2, 2, 2], [2, 2, 2], [2, 2, 2]], [[4, 4, 4], [4, 4, 4], [4, 4, 4]]], mstype.float32) assert np.allclose(output.asnumpy(), expect_output.asnumpy()) assert np.allclose(assign_add_var.asnumpy(), expect_assign_add_var.asnumpy()) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard def test_vmap_reduce(): """ Feature: vmap Description: This case mainly tests the following `vmap` application scenarios in graph mode: 1. The specific VmapRule of `ReduceSum` operation. 2. The `out_axes` is a single integer, which automatically match to multiple outputs. Expectation: success """ class ReduceNet(nn.Cell): def __init__(self): super(ReduceNet, self).__init__() self.reduce_sum = P.ReduceSum(keep_dims=False) self.reduce_sum_keep_dims = P.ReduceSum(keep_dims=True) def construct(self, x): out1 = self.reduce_sum(x) out2 = self.reduce_sum_keep_dims(x) out3 = self.reduce_sum(x, 1) out4 = self.reduce_sum_keep_dims(x, 1) out5 = self.reduce_sum(x, (0, 1)) out6 = self.reduce_sum_keep_dims(x, (0, 1)) output = (out1, out2, out3, out4, out5, out6) return output class VmapNet(nn.Cell): def __init__(self, net): super(VmapNet, self).__init__() self.net = net def construct(self, x): vmap_function = F.vmap(self.net, 1, 0) output = vmap_function(x) return output x_hat = Tensor(np.array([[[[1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3]], [[4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]], [[7, 7, 7, 7, 7, 7], [8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9]]], [[[1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3]], [[4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]], [[7, 7, 7, 7, 7, 7], [8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9]]], [[[1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2], [3, 3, 3, 3, 3, 3]], [[4, 4, 4, 4, 4, 4], [5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]], [[7, 7, 7, 7, 7, 7], [8, 8, 8, 8, 8, 8], [9, 9, 9, 9, 9, 9]]]]), mstype.float32) result1, result2, result3, result4, result5, result6 = VmapNet(ReduceNet())(x_hat) expect_result1 = Tensor([108, 270, 432], mstype.float32) assert np.allclose(result1.asnumpy(), expect_result1.asnumpy()) expect_result2 = Tensor([[[[108]]], [[[270]]], [[[432]]]], mstype.float32) assert np.allclose(result2.asnumpy(), expect_result2.asnumpy()) expect_result3 = Tensor([[[6, 6, 6, 6, 6, 6], [6, 6, 6, 6, 6, 6], [6, 6, 6, 6, 6, 6]], [[15, 15, 15, 15, 15, 15], [15, 15, 15, 15, 15, 15], [15, 15, 15, 15, 15, 15]], [[24, 24, 24, 24, 24, 24], [24, 24, 24, 24, 24, 24], [24, 24, 24, 24, 24, 24]]], mstype.float32) assert np.allclose(result3.asnumpy(), expect_result3.asnumpy()) expect_result4 = Tensor([[[[6, 6, 6, 6, 6, 6]], [[6, 6, 6, 6, 6, 6]], [[6, 6, 6, 6, 6, 6]]], [[[15, 15, 15, 15, 15, 15]], [[15, 15, 15, 15, 15, 15]], [[15, 15, 15, 15, 15, 15]]], [[[24, 24, 24, 24, 24, 24]], [[24, 24, 24, 24, 24, 24]], [[24, 24, 24, 24, 24, 24]]]], mstype.float32) assert np.allclose(result4.asnumpy(), expect_result4.asnumpy()) expect_result5 = Tensor([[18, 18, 18, 18, 18, 18], [45, 45, 45, 45, 45, 45], [72, 72, 72, 72, 72, 72]], mstype.float32) assert np.allclose(result5.asnumpy(), expect_result5.asnumpy()) expect_result6 = Tensor([[[[18, 18, 18, 18, 18, 18]]], [[[45, 45, 45, 45, 45, 45]]], [[[72, 72, 72, 72, 72, 72]]]], mstype.float32) assert np.allclose(result6.asnumpy(), expect_result6.asnumpy()) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard def test_vmap_general_rule(): """ Feature: vmap Description: This case mainly tests the following `vmap` application scenarios in graph mode: 1. The general VmapRule. 2. The specific VmapRule of `Reshape` operation. 3. The same `vmap` object is called multiple times. 4. The `mindspore.numpy` objects as the arguments. Expectation: success """ def convolve(x, w): output = [] for i in range(1, len(x) - 1): output.append(mnp.dot(x[i - 1 : i + 2], w)) return mnp.stack(output) x = mnp.arange(5).astype('float32') w = mnp.array([1., 2., 3.]) vmap_function = vmap(convolve) x1 = mnp.stack([x, x, x]) w1 = mnp.stack([w, w, w]) result1 = vmap_function(x1, w1) expect_result1 = Tensor([[8, 14, 20], [8, 14, 20], [8, 14, 20]], mstype.float32) assert np.allclose(result1.asnumpy(), expect_result1.asnumpy()) x2 = mnp.stack([x, x + 1, x + 2]) w2 = mnp.stack([w, w * 2, w * 3]) result2 = vmap_function(x2, w2) expect_result2 = Tensor([[8, 14, 20], [28, 40, 52], [60, 78, 96]], mstype.float32) assert np.allclose(result2.asnumpy(), expect_result2.asnumpy()) @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard def test_vmap_nested_axes(): """ Feature: vmap Description: This case mainly tests the following `vmap` application scenarios in graph mode: 1. The nested inputs as the vmap's arguments. 2. One element of the `in_axes` is a minus integer. 3. Some outputs of the function is scalars with destination axis non-None. 4. The `in_axes` is nested Tuple and List. 5. VmapRule for that operators with indefinite length as input, such as `Stack`. Expectation: success """ class AddNet(nn.Cell): def __init__(self): super(AddNet, self).__init__() self.inner_tensor = Tensor([5, 6], mstype.float32) self.inner_para = Parameter(Tensor([5, 6], mstype.float32), name='inner_para') def construct(self, x, y): a = 1 b = 2 c = 3 d = self.inner_tensor + a e = F.stack((self.inner_para, self.inner_para)) return ((a, b), c), d, e x_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) y_hat = Tensor([[1, 2, 3], [4, 5, 6]], mstype.float32) z_hat = 1 ((res1, res2), res3), res4, res5 = \ vmap(AddNet(), in_axes=(1, [-1, None]), out_axes=((0, None), 0, None))(x_hat, (y_hat, z_hat)) expect_res1 = Tensor([1, 1, 1], mstype.float32) expect_res2 = Tensor([2, 2, 2], mstype.float32) expect_res3 = 3 expect_res4 = Tensor([[6, 7], [6, 7], [6, 7]], mstype.float32) expect_res5 = Tensor([[5, 6], [5, 6]], mstype.float32) assert np.allclose(res1.asnumpy(), expect_res1.asnumpy()) assert np.allclose(res2.asnumpy(), expect_res2.asnumpy()) assert res3 == expect_res3 assert np.allclose(res4.asnumpy(), expect_res4.asnumpy()) assert np.allclose(res5.asnumpy(), expect_res5.asnumpy())