Browse Source

!20520 [AutoParallel]Add op AllToAllv

Merge pull request !20520 from lichen/add_op_AllToAllv
tags/v1.4.0
i-robot Gitee 4 years ago
parent
commit
a7d40fc220
8 changed files with 257 additions and 1 deletions
  1. +1
    -0
      mindspore/core/base/core_ops.h
  2. +72
    -0
      mindspore/core/ops/alltoallv.cc
  3. +43
    -0
      mindspore/core/ops/alltoallv.h
  4. +1
    -1
      mindspore/nn/wrap/cell_wrapper.py
  5. +1
    -0
      mindspore/ops/_grad_experimental/__init__.py
  6. +33
    -0
      mindspore/ops/_grad_experimental/grad_comm_ops.py
  7. +26
    -0
      mindspore/ops/operations/_inner_ops.py
  8. +80
    -0
      tests/ut/python/parallel/test_alltoall_v.py

+ 1
- 0
mindspore/core/base/core_ops.h View File

@@ -383,6 +383,7 @@ inline const PrimitivePtr kPrimVirtualOutput = std::make_shared<Primitive>("_Vir
// Shared Primitive handles, each constructed with the operator name given as the string argument.
// NOTE(review): these look like the communication-operator group (Send/Receive/AllReduce/...) — confirm
// against the surrounding section of core_ops.h.
inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send");
inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("Receive");
inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
// Handle for the AllToAllv operator; referenced as prim::kPrimAllToAllv when its infer function is registered.
inline const PrimitivePtr kPrimAllToAllv = std::make_shared<Primitive>("AllToAllv");
inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap");
inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast");
inline const PrimitivePtr kPrimAllGather = std::make_shared<Primitive>("AllGather");


+ 72
- 0
mindspore/core/ops/alltoallv.cc View File

@@ -0,0 +1,72 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ops/alltoallv.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"

namespace mindspore {
namespace ops {
// Builds the output shape of AllToAllv from the primitive's "recv_shapes" attribute.
//
// The op must receive exactly one input and that input must be a tuple. The result is a tuple shape
// holding one abstract::Shape per entry of "recv_shapes", in the same order.
//
// primitive:  the AllToAllv primitive carrying the "recv_shapes" attribute.
// input_args: abstract inputs of the op (exactly one AbstractTuple is expected).
// Raises a framework exception if the primitive, the attribute, or any of its entries is null or is not a
// value sequence.
abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const auto &prim_name = primitive->name();
  CheckAndConvertUtils::CheckInteger("input_numbers", input_args.size(), kEqual, 1, prim_name);
  CheckAndConvertUtils::CheckArgs<abstract::AbstractTuple>(prim_name, input_args, 0);
  auto recv_shapes = primitive->GetAttr(RecvShapes);
  MS_EXCEPTION_IF_NULL(recv_shapes);
  auto shapes_seq = recv_shapes->cast<ValueSequeuePtr>();
  MS_EXCEPTION_IF_NULL(shapes_seq);
  // The element list is only read, so bind it by const reference instead of copying the vector.
  const auto &shapes_value = shapes_seq->value();
  abstract::BaseShapePtrList base_shape_list;
  base_shape_list.reserve(shapes_value.size());
  for (const auto &value : shapes_value) {
    // Guard the element itself: a null entry would otherwise be dereferenced by the cast below.
    MS_EXCEPTION_IF_NULL(value);
    auto each_shape_value = value->cast<ValueSequeuePtr>();
    MS_EXCEPTION_IF_NULL(each_shape_value);
    std::vector<int64_t> each_shape = GetValue<std::vector<int64_t>>(each_shape_value);
    // std::make_shared never yields a null pointer (it throws on allocation failure), so no check is needed.
    BaseShapePtr base_shape = std::make_shared<abstract::Shape>(each_shape);
    base_shape_list.push_back(base_shape);
  }
  return std::make_shared<abstract::TupleShape>(base_shape_list);
}

// Builds the output type of AllToAllv: a tuple that repeats the "recv_type" attribute once per entry of
// the "recv_shapes" attribute.
//
// primitive:  the AllToAllv primitive carrying the "recv_shapes" and "recv_type" attributes.
// input_args: abstract inputs of the op (exactly one non-null input is expected).
// Raises a framework exception if the primitive, the input, or either attribute is missing or has an
// unexpected kind.
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const auto &prim_name = primitive->name();
  CheckAndConvertUtils::CheckInteger("AllToAllv infer", input_args.size(), kEqual, 1, prim_name);
  MS_EXCEPTION_IF_NULL(input_args[0]);
  auto recv_shapes = primitive->GetAttr(RecvShapes);
  MS_EXCEPTION_IF_NULL(recv_shapes);
  auto shapes_seq = recv_shapes->cast<ValueSequeuePtr>();
  MS_EXCEPTION_IF_NULL(shapes_seq);
  // Only the number of receive shapes is needed here, so the element vector is not copied.
  const auto out_num = shapes_seq->value().size();
  // Check the attribute before casting: dereferencing a missing "recv_type" would be undefined behaviour.
  auto recv_type_attr = primitive->GetAttr(RecvType);
  MS_EXCEPTION_IF_NULL(recv_type_attr);
  auto recv_type = recv_type_attr->cast<TypePtr>();
  MS_EXCEPTION_IF_NULL(recv_type);
  std::vector<TypePtr> type_vec(out_num, recv_type);
  return std::make_shared<Tuple>(type_vec);
}

// Infer entry point for AllToAllv: combines the inferred shape and type into one abstract value.
// The analysis engine argument is unused.
AbstractBasePtr AllToAllvInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const std::vector<AbstractBasePtr> &input_args) {
  // Type inference runs first, so its validation errors surface before those of shape inference.
  auto inferred_type = InferType(primitive, input_args);
  auto inferred_shape = InferShape(primitive, input_args);
  return abstract::MakeAbstract(inferred_shape, inferred_type);
}

// Registers AllToAllvInfer as the eval implementation associated with prim::kPrimAllToAllv.
// NOTE(review): the meaning of the trailing `nullptr` and `true` arguments is defined by the macro —
// confirm against abstract/primitive_infer_map.h.
REGISTER_PRIMITIVE_EVAL_IMPL(AllToAllv, prim::kPrimAllToAllv, AllToAllvInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

+ 43
- 0
mindspore/core/ops/alltoallv.h View File

@@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_OPS_ALLTOALLV_H_
#define MINDSPORE_CORE_OPS_ALLTOALLV_H_
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
// Name passed to the PrimitiveC constructor by the AllToAllv class.
constexpr auto kNameAllToAllv = "AllToAllv";
// Attribute key for the list of receive shapes; each entry becomes one output shape during inference.
constexpr auto RecvShapes = "recv_shapes";
// Attribute key for the receive data type; it is used as the type of every output during inference.
constexpr auto RecvType = "recv_type";
// Primitive class for the AllToAllv operator.
class AllToAllv : public PrimitiveC {
public:
// Constructs the primitive under the name kNameAllToAllv.
AllToAllv() : PrimitiveC(kNameAllToAllv) {}
~AllToAllv() = default;
MS_DECLARE_PARENT(AllToAllv, PrimitiveC);
// Empty initializer: sets no attributes on the primitive.
void Init() {}
};
// Infer function for AllToAllv: returns the abstract value (shape and type) of the op's output.
// The analysis engine parameter is unnamed in this declaration.
AbstractBasePtr AllToAllvInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
// Shared-pointer alias for the AllToAllv primitive class.
// NOTE(review): the alias is spelled PrimAllToAllPtr (no trailing "v") — confirm whether that is intended.
using PrimAllToAllPtr = std::shared_ptr<AllToAllv>;
} // namespace ops
} // namespace mindspore

#endif // MINDSPORE_CORE_OPS_ALLTOALLV_H_

+ 1
- 1
mindspore/nn/wrap/cell_wrapper.py View File

@@ -501,7 +501,7 @@ class PipelineCell(Cell):

Args:
network (Cell): The target network to wrap.
micro_size (Int): MicroBatch size.
micro_size (int): MicroBatch size.

Examples:
>>> net = Net()


+ 1
- 0
mindspore/ops/_grad_experimental/__init__.py View File

@@ -18,5 +18,6 @@ from .._grad.grad_base import get_bprop_fn
from . import grad_array_ops
from . import grad_inner_ops
from . import grad_nn_ops
from . import grad_comm_ops

__all__ = ['get_bprop_fn']

+ 33
- 0
mindspore/ops/_grad_experimental/grad_comm_ops.py View File

@@ -0,0 +1,33 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Generate bprop for comm ops"""
from .._grad.grad_base import bprop_getters
from ..operations._inner_ops import AllToAllv


@bprop_getters.register(AllToAllv)
def get_bprop_alltoallv(self):
    """
    Generate bprop for AllToAllv.

    The gradient is computed by another AllToAllv whose communication pattern mirrors the forward op:
    send and receive ranks are swapped, and so are the forward/backward receive shapes.

    Args:
        self (AllToAllv): The forward primitive whose attributes configure the gradient op.

    Returns:
        function, the bprop taking ``(x, out, dout)`` and returning a one-element tuple with the gradient of x.
    """
    group = self.group
    # Gradients flow in the opposite direction: send to the ranks the forward op received from, and vice versa.
    send_rank_ids = self.recv_rank_ids
    recv_rank_ids = self.send_rank_ids
    # The gradient op receives what the forward op sent, so its receive shapes are the forward op's
    # backward shapes; its own backward shapes are the forward receive shapes.
    recv_shapes = self.recv_shapes_backward
    recv_shapes_backward = self.recv_shapes
    recv_type = self.recv_type
    alltoallv_grad = AllToAllv(send_rank_ids, recv_rank_ids, recv_shapes, recv_shapes_backward, recv_type, group)

    def bprop(x, out, dout):
        return (alltoallv_grad(dout),)

    return bprop

+ 26
- 0
mindspore/ops/operations/_inner_ops.py View File

@@ -492,6 +492,32 @@ class Receive(PrimitiveWithInfer):
return self.dtype


class AllToAllv(Primitive):
    """
    AllToAllv is a collective operation.

    AllToAllv sends data from the local rank to the ranks in send_rank_ids, while receiving data from the
    ranks in recv_rank_ids.

    Args:
        send_rank_ids (list): Ranks which the data is sent to.
        recv_rank_ids (list): Ranks which the data is received from.
        recv_shapes (list): Shapes of the data received from recv_rank_ids.
        recv_shapes_backward (list): Shapes of the data received from send_rank_ids in the backward pass.
        recv_type (type): Data type of the data received from recv_rank_ids.
        group (str): The communication group to work on. Default: GlobalComm.WORLD_COMM_GROUP.
    """

    @prim_attr_register
    def __init__(self, send_rank_ids, recv_rank_ids, recv_shapes, recv_shapes_backward, recv_type,
                 group=GlobalComm.WORLD_COMM_GROUP):
        self.init_prim_io_names(inputs=['x'], outputs=['output'])
        self.send_rank_ids = send_rank_ids
        self.recv_rank_ids = recv_rank_ids
        self.recv_shapes = recv_shapes
        self.recv_shapes_backward = recv_shapes_backward
        self.recv_type = recv_type
        # Stored explicitly because the bprop getter for this primitive reads `self.group`.
        self.group = group


class MatrixSetDiag(PrimitiveWithInfer):
r"""
Modifies the batched diagonal part of a batched tensor.


+ 80
- 0
tests/ut/python/parallel/test_alltoall_v.py View File

@@ -0,0 +1,80 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.context as context
from mindspore import Tensor, Parameter
import mindspore.nn as nn
from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, Momentum
from mindspore.ops import operations as P
from mindspore.ops.operations._inner_ops import AllToAllv


class MatMulNet(nn.Cell):
    """MatMul -> Mul -> AllToAllv network that passes a two-element tuple to AllToAllv."""

    def __init__(self, weight1):
        super().__init__()
        self.matmul = P.MatMul()
        self.mul = P.Mul()
        self.alltoallv = AllToAllv(
            send_rank_ids=[0],
            recv_rank_ids=[1, 2],
            recv_shapes=([32, 32], [32, 64]),
            recv_shapes_backward=([32, 32], [32, 16]),
            recv_type=ms.float32,
        )
        self.weight1 = Parameter(weight1, "w1")

    def construct(self, x1, x2):
        product = self.matmul(x1, x2)
        scaled = self.mul(product, self.weight1)
        # Both the scaled product and the raw first input are exchanged; only the first result is returned.
        received = self.alltoallv((scaled, x1))
        return received[0]


class MatMulNet2(nn.Cell):
    """MatMul -> Mul -> AllToAllv network that passes a single-element tuple to AllToAllv."""

    def __init__(self, weight1):
        super().__init__()
        self.matmul = P.MatMul()
        self.mul = P.Mul()
        self.alltoallv = AllToAllv(
            send_rank_ids=[0],
            recv_rank_ids=[1, 2],
            recv_shapes=([32, 32], [32, 64]),
            recv_shapes_backward=([32, 32],),
            recv_type=ms.float32,
        )
        self.weight1 = Parameter(weight1, "w1")

    def construct(self, x1, x2):
        product = self.matmul(x1, x2)
        scaled = self.mul(product, self.weight1)
        # Only the scaled product is exchanged; the first received tensor is the network output.
        received = self.alltoallv((scaled,))
        return received[0]


# All-ones float32 fixtures shared by the tests below:
# _w1 (32x32) is the Mul weight; _x1 (32x16) and _x2 (16x32) are the MatMul operands passed to compile.
_w1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
_x1 = Tensor(np.ones([32, 16]), dtype=ms.float32)
_x2 = Tensor(np.ones([16, 32]), dtype=ms.float32)


def compile_net(net):
    """Wrap `net` in a Momentum-driven TrainOneStepCell and compile it in graph mode with _x1 and _x2."""
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
    momentum_opt = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    one_step_cell = TrainOneStepCell(net, momentum_opt)
    one_step_cell.set_train()
    _executor.compile(one_step_cell, _x1, _x2)


def test_AllToAllv_two_inputs():
    """Compile the network that feeds two tensors to AllToAllv, as rank 0 of 8 devices."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    compile_net(MatMulNet(_w1))


def test_AllToAllv_single_input():
    """Compile the network that feeds a single tensor to AllToAllv, as rank 0 of 8 devices."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    compile_net(MatMulNet2(_w1))

Loading…
Cancel
Save