From bb5d4212f7aa34f84693cb5ae5dc4fab348ae5e2 Mon Sep 17 00:00:00 2001 From: Xiaoda Zhang Date: Thu, 22 Jul 2021 14:52:55 +0800 Subject: [PATCH] enable All2All in inferring redistribution ops --- mindspore/ccsrc/frontend/parallel/context.cc | 3 + mindspore/ccsrc/frontend/parallel/context.h | 4 + .../ccsrc/frontend/parallel/step_parallel.cc | 2 +- .../redistribution_operator_infer.cc | 3 +- mindspore/ccsrc/pipeline/jit/init.cc | 2 + mindspore/parallel/_auto_parallel_context.py | 28 +++++- tests/ut/python/parallel/test_batchmm.py | 89 +++++++++++++++++++ 7 files changed, 126 insertions(+), 5 deletions(-) create mode 100644 tests/ut/python/parallel/test_batchmm.py diff --git a/mindspore/ccsrc/frontend/parallel/context.cc b/mindspore/ccsrc/frontend/parallel/context.cc index 0e303414bf..0b3264129a 100644 --- a/mindspore/ccsrc/frontend/parallel/context.cc +++ b/mindspore/ccsrc/frontend/parallel/context.cc @@ -72,6 +72,7 @@ void ParallelContext::Reset() { optimizer_weight_shard_size_ = -1; optimizer_weight_shard_aggregated_save_ = false; sharding_propagation_ = false; + enable_all2all_ = false; } void ParallelContext::set_device_num(int64_t device_num) { @@ -269,5 +270,7 @@ void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func } void ParallelContext::set_sharding_propagation(const bool stra_pto) { sharding_propagation_ = stra_pto; } + +void ParallelContext::set_enable_all2all(const bool enable) { enable_all2all_ = enable; } } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/parallel/context.h b/mindspore/ccsrc/frontend/parallel/context.h index ac475237df..081394bd4b 100644 --- a/mindspore/ccsrc/frontend/parallel/context.h +++ b/mindspore/ccsrc/frontend/parallel/context.h @@ -126,6 +126,8 @@ class ParallelContext { std::string communi_parallel_mode() const { return communi_parallel_mode_; } void set_sharding_propagation(const bool); bool sharding_propagation() const { return sharding_propagation_; } + void 
set_enable_all2all(const bool); + bool enable_all2all() const { return enable_all2all_; } void Reset(); void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph); @@ -165,6 +167,8 @@ class ParallelContext { // In AUTO_PARALLEL mode, 'sharding_propagation_' = True indicates that sharding-configured operators // will propagate the sharding strategies to other operators with minimum redistribution cost. bool sharding_propagation_; + // Enable AllToAll or not. If false, use AllGather and Split. + bool enable_all2all_; }; } // namespace parallel diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index a741deada9..fea1fba331 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -815,7 +815,7 @@ std::vector ReplaceOpInput(const Operator &replace_op, const std::st replace_input.push_back(node->input(i)); } } - + SetCommunicationOpGroupLabel(replace_input); return replace_input; } diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.cc b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.cc index ee9a95f69f..f1a8a77347 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.cc +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.cc @@ -19,6 +19,7 @@ #include #include "frontend/parallel/device_manager.h" +#include "frontend/parallel/context.h" namespace mindspore { namespace parallel { @@ -133,7 +134,7 @@ Status RedistributionOperatorInfer::InferPermuteByAxis() { [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) { int64_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim); int64_t dev_num = dev_mat_.GetDimByReverseIdx(LongToSize(out_dim)); - if (is_cost_model_) { + if (ParallelContext::GetInstance()->enable_all2all()) { int64_t dev_dim = 
in_tensor_map_.GetDimByIdx(LongToUlong(cat_dim)); Args args_alltoall = {dev_mat_.GetDimByReverseIdx(LongToUlong(dev_dim)), UlongToLong(index), cat_dim, dev_dim, dev_num}; diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index 3f94066487..c3a98a42eb 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -189,6 +189,8 @@ PYBIND11_MODULE(_c_expression, m) { .def("set_sharding_propagation", &ParallelContext::set_sharding_propagation, "Set sharding strategy propagation value.") .def("get_sharding_propagation", &ParallelContext::sharding_propagation, "Get sharding strategy propagation value.") + .def("set_enable_alltoall", &ParallelContext::set_enable_all2all, "Set the enabling AllToAll value.") + .def("get_enable_alltoall", &ParallelContext::enable_all2all, "Get the enabling AllToAll value.") .def("reset", &ParallelContext::Reset, "Reset auto parallel context."); (void)py::class_>(m, "CostModelContext") diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index f0e967b71c..8602cb15ce 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -463,6 +463,24 @@ class _AutoParallelContext: self.check_context_handle() return self._context_handle.get_sharding_propagation() + def set_enable_alltoall(self, enable_a2a): + """ + Set the value of enabling AllToAll. If False, AllGather and Split are used to circumvent AllToAll. + Default: False. + + Args: + enable_a2a (bool): Enable/disable AllToAll. 
+ """ + self.check_context_handle() + if not isinstance(enable_a2a, bool): + raise TypeError("'enable_a2a' is an invalid type.") + self._context_handle.set_enable_alltoall(enable_a2a) + + def get_enable_alltoall(self): + """Get the value of enabling AllToAll.""" + self.check_context_handle() + return self._context_handle.get_enable_alltoall() + def set_communi_parallel_mode(self, communi_parallel_mode): """ Set communication parallel mode. @@ -584,7 +602,8 @@ _set_auto_parallel_context_func_map = { "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode, "optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size, "optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save, - "sharding_propagation": auto_parallel_context().set_sharding_propagation} + "sharding_propagation": auto_parallel_context().set_sharding_propagation, + "enable_alltoall": auto_parallel_context().set_enable_alltoall} _get_auto_parallel_context_func_map = { @@ -606,7 +625,8 @@ _get_auto_parallel_context_func_map = { "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode, "optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size, "optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save, - "sharding_propagation": auto_parallel_context().get_sharding_propagation} + "sharding_propagation": auto_parallel_context().get_sharding_propagation, + "enable_alltoall": auto_parallel_context().get_enable_alltoall} @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, @@ -616,7 +636,7 @@ _get_auto_parallel_context_func_map = { grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str, communi_parallel_mode=str, optimizer_weight_shard_size=int, optimizer_weight_shard_aggregated_save=bool, - sharding_propagation=bool) + sharding_propagation=bool, 
enable_alltoall=bool) def _set_auto_parallel_context(**kwargs): """ @@ -682,6 +702,8 @@ def _set_auto_parallel_context(**kwargs): the strategy-configured operators will propagate the strategies to other operators with minimum redistribution cost; otherwise, the algorithm will search the desired strategies. Default: False. + enable_alltoall (bool): Set the value of enabling AllToAll. If False, AllGather and Split are used to + circumvent AllToAll. Default: False. Raises: ValueError: If input key is not attribute in auto parallel context. diff --git a/tests/ut/python/parallel/test_batchmm.py b/tests/ut/python/parallel/test_batchmm.py new file mode 100644 index 0000000000..99969e2d7a --- /dev/null +++ b/tests/ut/python/parallel/test_batchmm.py @@ -0,0 +1,89 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import mindspore as ms +import mindspore.context as context +from mindspore import Tensor, Parameter +import mindspore.nn as nn +from mindspore.common.api import _executor +from mindspore.nn import TrainOneStepCell, Momentum +from mindspore.ops import operations as P + +class Net(nn.Cell): + def __init__(self, wi, wo, stra1=None, stra2=None, stra3=None, stra4=None, + stra5=None, stra6=None): + super(Net, self).__init__() + self.relu = P.ReLU().shard(stra1) + self.transpose = P.Transpose().shard(stra2) + self.wi = Parameter(wi, "wi") + self.batch_mm = P.BatchMatMul().shard(stra3) + self.wo = Parameter(wo, "wo") + self.batch_mm2 = P.BatchMatMul().shard(stra4) + self.transpose2 = P.Transpose().shard(stra5) + self.relu2 = P.ReLU().shard(stra6) + self.reshape = P.Reshape() + self.reshape2 = P.Reshape() + + def construct(self, x): + output = self.relu(x) + trans_out = self.transpose(output, (2, 0, 3, 1)) + output = self.reshape(trans_out, + (trans_out.shape[0], trans_out.shape[1]*trans_out.shape[2], trans_out.shape[3])) + output = self.batch_mm(output, self.wi) + output = self.batch_mm2(output, self.wo) + output = self.reshape2(output, trans_out.shape) + output = self.transpose2(output, (1, 3, 0, 2)) + output = self.relu2(output) + return output + +_x = Tensor(np.ones([32, 16, 48, 128]), dtype=ms.float32) +_wi = Tensor(np.ones([48, 16, 64]), dtype=ms.float32) +_wo = Tensor(np.ones([48, 64, 16]), dtype=ms.float32) + + +def compile_net(net): + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + train_net = TrainOneStepCell(net, optimizer) + train_net.set_auto_parallel() + train_net.set_train() + _executor.compile(train_net, _x) + context.reset_auto_parallel_context() + + +def test_batchmm(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, enable_alltoall=True, + global_rank=0) + stra1 = ((8, 1, 1, 1),) + stra2 = ((8, 1, 
1, 1),) + stra3 = ((8, 1, 1), (8, 1, 1)) + stra4 = ((8, 1, 1), (8, 1, 1)) + stra5 = ((8, 1, 1, 1),) + stra6 = ((8, 1, 1, 1),) + net = Net(_wi, _wo, stra1=stra1, stra2=stra2, stra3=stra3, stra4=stra4, stra5=stra5, stra6=stra6) + compile_net(net) + + +def test_batchmm2(): + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", enable_alltoall=True, + device_num=32, global_rank=0) + stra1 = ((4, 1, 1, 1),) + stra2 = ((4, 1, 1, 1),) + stra3 = ((4, 1, 1), (4, 1, 8)) + stra4 = ((4, 1, 8), (4, 8, 1)) + stra5 = ((4, 1, 1, 1),) + stra6 = ((4, 1, 1, 1),) + net = Net(_wi, _wo, stra1=stra1, stra2=stra2, stra3=stra3, stra4=stra4, stra5=stra5, stra6=stra6) + compile_net(net)