Merge pull request !20520 from lichen/add_op_AllToAllv (tag: v1.4.0)
| @@ -383,6 +383,7 @@ inline const PrimitivePtr kPrimVirtualOutput = std::make_shared<Primitive>("_Vir | |||
| inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send"); | |||
| inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("Receive"); | |||
| inline const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce"); | |||
| inline const PrimitivePtr kPrimAllToAllv = std::make_shared<Primitive>("AllToAllv"); | |||
| inline const PrimitivePtr kPrimAllSwap = std::make_shared<Primitive>("AllSwap"); | |||
| inline const PrimitivePtr kPrimBroadcast = std::make_shared<Primitive>("Broadcast"); | |||
| inline const PrimitivePtr kPrimAllGather = std::make_shared<Primitive>("AllGather"); | |||
| @@ -0,0 +1,72 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/alltoallv.h" | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| CheckAndConvertUtils::CheckInteger("input_numbers", input_args.size(), kEqual, 1, prim_name); | |||
| CheckAndConvertUtils::CheckArgs<abstract::AbstractTuple>(prim_name, input_args, 0); | |||
| auto recv_shapes = primitive->GetAttr(RecvShapes); | |||
| MS_EXCEPTION_IF_NULL(recv_shapes); | |||
| auto shapes_seq = recv_shapes->cast<ValueSequeuePtr>(); | |||
| MS_EXCEPTION_IF_NULL(shapes_seq); | |||
| auto shapes_value = shapes_seq->value(); | |||
| abstract::BaseShapePtrList base_shape_list; | |||
| for (auto &value : shapes_value) { | |||
| auto each_shape_value = value->cast<ValueSequeuePtr>(); | |||
| MS_EXCEPTION_IF_NULL(each_shape_value); | |||
| std::vector<int64_t> each_shape = GetValue<std::vector<int64_t>>(each_shape_value); | |||
| BaseShapePtr base_shape = std::make_shared<abstract::Shape>(each_shape); | |||
| MS_EXCEPTION_IF_NULL(base_shape); | |||
| base_shape_list.push_back(base_shape); | |||
| } | |||
| return std::make_shared<abstract::TupleShape>(base_shape_list); | |||
| } | |||
| TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| MS_EXCEPTION_IF_NULL(primitive); | |||
| auto prim_name = primitive->name(); | |||
| CheckAndConvertUtils::CheckInteger("AllToAllv infer", input_args.size(), kEqual, 1, prim_name); | |||
| MS_EXCEPTION_IF_NULL(input_args[0]); | |||
| auto recv_shapes = primitive->GetAttr(RecvShapes); | |||
| MS_EXCEPTION_IF_NULL(recv_shapes); | |||
| auto shapes_seq = recv_shapes->cast<ValueSequeuePtr>(); | |||
| MS_EXCEPTION_IF_NULL(shapes_seq); | |||
| auto shapes_value = shapes_seq->value(); | |||
| auto out_num = shapes_value.size(); | |||
| auto recv_type = primitive->GetAttr(RecvType)->cast<TypePtr>(); | |||
| MS_EXCEPTION_IF_NULL(recv_type); | |||
| std::vector<TypePtr> type_vec(out_num, recv_type); | |||
| return std::make_shared<Tuple>(type_vec); | |||
| } | |||
| AbstractBasePtr AllToAllvInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args) { | |||
| auto type = InferType(primitive, input_args); | |||
| auto shape = InferShape(primitive, input_args); | |||
| return abstract::MakeAbstract(shape, type); | |||
| } | |||
| REGISTER_PRIMITIVE_EVAL_IMPL(AllToAllv, prim::kPrimAllToAllv, AllToAllvInfer, nullptr, true); | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_ALLTOALLV_H_ | |||
| #define MINDSPORE_CORE_OPS_ALLTOALLV_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| constexpr auto kNameAllToAllv = "AllToAllv"; | |||
| constexpr auto RecvShapes = "recv_shapes"; | |||
| constexpr auto RecvType = "recv_type"; | |||
/// \brief C++ registration side of the AllToAllv collective-communication
/// primitive. Attributes (send_rank_ids, recv_shapes, recv_type, ...) are
/// attached from the Python operator definition, so there is nothing for
/// Init() to configure here.
class AllToAllv : public PrimitiveC {
 public:
  /// \brief Constructor: registers the primitive under kNameAllToAllv.
  AllToAllv() : PrimitiveC(kNameAllToAllv) {}
  /// \brief Destructor.
  ~AllToAllv() = default;
  MS_DECLARE_PARENT(AllToAllv, PrimitiveC);
  /// \brief No-op: all configuration arrives as primitive attributes.
  void Init() {}
};
| AbstractBasePtr AllToAllvInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const std::vector<AbstractBasePtr> &input_args); | |||
| using PrimAllToAllPtr = std::shared_ptr<AllToAllv>; | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_ALLTOALLV_H_ | |||
| @@ -501,7 +501,7 @@ class PipelineCell(Cell): | |||
| Args: | |||
| network (Cell): The target network to wrap. | |||
| micro_size (Int): MicroBatch size. | |||
| micro_size (int): MicroBatch size. | |||
| Examples: | |||
| >>> net = Net() | |||
| @@ -18,5 +18,6 @@ from .._grad.grad_base import get_bprop_fn | |||
| from . import grad_array_ops | |||
| from . import grad_inner_ops | |||
| from . import grad_nn_ops | |||
| from . import grad_comm_ops | |||
| __all__ = ['get_bprop_fn'] | |||
| @@ -0,0 +1,33 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Generate bprop for comm ops""" | |||
| from .._grad.grad_base import bprop_getters | |||
| from ..operations._inner_ops import AllToAllv | |||
@bprop_getters.register(AllToAllv)
def get_bprop_alltoallv(self):
    """Generate bprop for AllToAllv.

    The gradient of an AllToAllv is another AllToAllv flowing the opposite
    direction: the send/recv rank lists are swapped, and the shapes received
    in the backward pass come from the forward op's ``recv_shapes_backward``
    attribute.
    """
    grad_op = AllToAllv(self.recv_rank_ids, self.send_rank_ids,
                        self.recv_shapes_backward, self.recv_shapes_backward,
                        self.recv_type, self.group)

    def bprop(x, out, dout):
        return (grad_op(dout),)

    return bprop
| @@ -492,6 +492,32 @@ class Receive(PrimitiveWithInfer): | |||
| return self.dtype | |||
class AllToAllv(Primitive):
    """
    AllToAllv is a collective communication operation.

    AllToAllv sends data from the local rank to the ranks listed in
    `send_rank_ids`, and at the same time receives data from the ranks
    listed in `recv_rank_ids`.

    Args:
        send_rank_ids (list): Ranks which the data is sent to.
        recv_rank_ids (list): Ranks which the data is received from.
        recv_shapes (list): Shapes of the data received from recv_rank_ids.
        recv_shapes_backward (list): Shapes of the data received from
            send_rank_ids in the backward pass.
        recv_type (type): Data type of the data received from recv_rank_ids.
        group (str): The communication group to work on.
            Default: GlobalComm.WORLD_COMM_GROUP.
    """

    @prim_attr_register
    def __init__(self, send_rank_ids, recv_rank_ids, recv_shapes, recv_shapes_backward, recv_type,
                 group=GlobalComm.WORLD_COMM_GROUP):
        self.init_prim_io_names(inputs=['x'], outputs=['output'])
        self.send_rank_ids = send_rank_ids
        self.recv_rank_ids = recv_rank_ids
        self.recv_shapes = recv_shapes
        self.recv_shapes_backward = recv_shapes_backward
        self.recv_type = recv_type
| class MatrixSetDiag(PrimitiveWithInfer): | |||
| r""" | |||
| Modifies the batched diagonal part of a batched tensor. | |||
| @@ -0,0 +1,80 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import mindspore as ms | |||
| import mindspore.context as context | |||
| from mindspore import Tensor, Parameter | |||
| import mindspore.nn as nn | |||
| from mindspore.common.api import _executor | |||
| from mindspore.nn import TrainOneStepCell, Momentum | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.operations._inner_ops import AllToAllv | |||
class MatMulNet(nn.Cell):
    """Network exercising AllToAllv with a two-element input tuple."""

    def __init__(self, weight1):
        super(MatMulNet, self).__init__()
        self.matmul = P.MatMul()
        self.mul = P.Mul()
        self.alltoallv = AllToAllv(send_rank_ids=[0], recv_rank_ids=[1, 2],
                                   recv_shapes=([32, 32], [32, 64]),
                                   recv_shapes_backward=([32, 32], [32, 16]),
                                   recv_type=ms.float32)
        self.weight1 = Parameter(weight1, "w1")

    def construct(self, x1, x2):
        product = self.matmul(x1, x2)
        scaled = self.mul(product, self.weight1)
        received = self.alltoallv((scaled, x1))
        return received[0]
class MatMulNet2(nn.Cell):
    """Network exercising AllToAllv with a single-element input tuple."""

    def __init__(self, weight1):
        super(MatMulNet2, self).__init__()
        self.matmul = P.MatMul()
        self.mul = P.Mul()
        self.alltoallv = AllToAllv(send_rank_ids=[0], recv_rank_ids=[1, 2],
                                   recv_shapes=([32, 32], [32, 64]),
                                   recv_shapes_backward=([32, 32],),
                                   recv_type=ms.float32)
        self.weight1 = Parameter(weight1, "w1")

    def construct(self, x1, x2):
        product = self.matmul(x1, x2)
        scaled = self.mul(product, self.weight1)
        received = self.alltoallv((scaled,))
        return received[0]
# Shared fixtures: the weight parameter and the two MatMul operands.
_w1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
_x1 = Tensor(np.ones([32, 16]), dtype=ms.float32)
_x2 = Tensor(np.ones([16, 32]), dtype=ms.float32)
def compile_net(net):
    """Compile `net` (wrapped in a Momentum TrainOneStepCell) in graph mode."""
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
    opt = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    train_net = TrainOneStepCell(net, opt)
    train_net.set_train()
    _executor.compile(train_net, _x1, _x2)
def test_AllToAllv_two_inputs():
    """AllToAllv fed a two-element tuple compiles under auto-parallel."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    compile_net(MatMulNet(_w1))
def test_AllToAllv_single_input():
    """AllToAllv fed a single-element tuple compiles under auto-parallel."""
    context.set_auto_parallel_context(device_num=8, global_rank=0)
    compile_net(MatMulNet2(_w1))