Merge pull request !20712 from Xiaoda/76-change-alltoall-redistribution (tags/v1.4.0)
| @@ -72,6 +72,7 @@ void ParallelContext::Reset() { | |||
| optimizer_weight_shard_size_ = -1; | |||
| optimizer_weight_shard_aggregated_save_ = false; | |||
| sharding_propagation_ = false; | |||
| enable_all2all_ = false; | |||
| } | |||
| void ParallelContext::set_device_num(int64_t device_num) { | |||
| @@ -269,5 +270,7 @@ void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func | |||
| } | |||
| void ParallelContext::set_sharding_propagation(const bool stra_pto) { sharding_propagation_ = stra_pto; } | |||
| void ParallelContext::set_enable_all2all(const bool enable) { enable_all2all_ = enable; } | |||
| } // namespace parallel | |||
| } // namespace mindspore | |||
| @@ -126,6 +126,8 @@ class ParallelContext { | |||
| std::string communi_parallel_mode() const { return communi_parallel_mode_; } | |||
| void set_sharding_propagation(const bool); | |||
| bool sharding_propagation() const { return sharding_propagation_; } | |||
| void set_enable_all2all(const bool); | |||
| bool enable_all2all() const { return enable_all2all_; } | |||
| void Reset(); | |||
| void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph); | |||
| @@ -165,6 +167,8 @@ class ParallelContext { | |||
| // In AUTO_PARALLEL mode, 'sharding_propagation_' = True indicates that sharding-configured operators | |||
| // will propagate the sharding strategies to other operators with minimum redistribution cost. | |||
| bool sharding_propagation_; | |||
| // Enable AllToAll or not. If false, use AllGather and Split. | |||
| bool enable_all2all_; | |||
| }; | |||
| } // namespace parallel | |||
| @@ -815,7 +815,7 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st | |||
| replace_input.push_back(node->input(i)); | |||
| } | |||
| } | |||
| SetCommunicationOpGroupLabel(replace_input); | |||
| return replace_input; | |||
| } | |||
| @@ -19,6 +19,7 @@ | |||
| #include <utility> | |||
| #include "frontend/parallel/device_manager.h" | |||
| #include "frontend/parallel/context.h" | |||
| namespace mindspore { | |||
| namespace parallel { | |||
| @@ -133,7 +134,7 @@ Status RedistributionOperatorInfer::InferPermuteByAxis() { | |||
| [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) { | |||
| int64_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim); | |||
| int64_t dev_num = dev_mat_.GetDimByReverseIdx(LongToSize(out_dim)); | |||
| if (is_cost_model_) { | |||
| if (ParallelContext::GetInstance()->enable_all2all()) { | |||
| int64_t dev_dim = in_tensor_map_.GetDimByIdx(LongToUlong(cat_dim)); | |||
| Args args_alltoall = {dev_mat_.GetDimByReverseIdx(LongToUlong(dev_dim)), UlongToLong(index), cat_dim, dev_dim, | |||
| dev_num}; | |||
| @@ -189,6 +189,8 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| .def("set_sharding_propagation", &ParallelContext::set_sharding_propagation, | |||
| "Set sharding strategy propagation value.") | |||
| .def("get_sharding_propagation", &ParallelContext::sharding_propagation, "Get sharding strategy propagation value.") | |||
| .def("set_enable_alltoall", &ParallelContext::set_enable_all2all, "Set the enabling AllToAll value.") | |||
| .def("get_enable_alltoall", &ParallelContext::enable_all2all, "Get the enabling AllToAll value.") | |||
| .def("reset", &ParallelContext::Reset, "Reset auto parallel context."); | |||
| (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext") | |||
| @@ -463,6 +463,24 @@ class _AutoParallelContext: | |||
| self.check_context_handle() | |||
| return self._context_handle.get_sharding_propagation() | |||
def set_enable_alltoall(self, enable_a2a):
    """
    Set the value of enabling AllToAll. If False, AllGather and Split are used to circumvent AllToAll.
    Default: False.

    Args:
        enable_a2a (bool): Enable/disable AllToAll.

    Raises:
        TypeError: If `enable_a2a` is not a bool.
    """
    self.check_context_handle()
    if not isinstance(enable_a2a, bool):
        # Report both the expected and the received type so the caller can see what went wrong.
        raise TypeError("'enable_a2a' must be a bool, but got {}.".format(type(enable_a2a).__name__))
    self._context_handle.set_enable_alltoall(enable_a2a)
def get_enable_alltoall(self):
    """Return the AllToAll switch as reported by the context handle."""
    self.check_context_handle()
    enabled = self._context_handle.get_enable_alltoall()
    return enabled
| def set_communi_parallel_mode(self, communi_parallel_mode): | |||
| """ | |||
| Set communication parallel mode. | |||
| @@ -584,7 +602,8 @@ _set_auto_parallel_context_func_map = { | |||
| "communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode, | |||
| "optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size, | |||
| "optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save, | |||
| "sharding_propagation": auto_parallel_context().set_sharding_propagation} | |||
| "sharding_propagation": auto_parallel_context().set_sharding_propagation, | |||
| "enable_alltoall": auto_parallel_context().set_enable_alltoall} | |||
| _get_auto_parallel_context_func_map = { | |||
| @@ -606,7 +625,8 @@ _get_auto_parallel_context_func_map = { | |||
| "communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode, | |||
| "optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size, | |||
| "optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save, | |||
| "sharding_propagation": auto_parallel_context().get_sharding_propagation} | |||
| "sharding_propagation": auto_parallel_context().get_sharding_propagation, | |||
| "enable_alltoall": auto_parallel_context().get_enable_alltoall} | |||
| @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, | |||
| @@ -616,7 +636,7 @@ _get_auto_parallel_context_func_map = { | |||
| grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str, | |||
| communi_parallel_mode=str, optimizer_weight_shard_size=int, | |||
| optimizer_weight_shard_aggregated_save=bool, | |||
| sharding_propagation=bool) | |||
| sharding_propagation=bool, enable_alltoall=bool) | |||
| def _set_auto_parallel_context(**kwargs): | |||
| """ | |||
| @@ -682,6 +702,8 @@ def _set_auto_parallel_context(**kwargs): | |||
| the strategy-configured operators will propagate the strategies to other | |||
| operators with minimum redistribution cost; otherwise, the algorithm will | |||
| search the desired strategies. Default: False. | |||
| enable_alltoall (bool): Set the value of enabling AllToAll. If False, AllGather and Split are used to | |||
| circumvent AllToAll. Default: False. | |||
| Raises: | |||
| ValueError: If input key is not attribute in auto parallel context. | |||
| @@ -0,0 +1,89 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import numpy as np | |||
| import mindspore as ms | |||
| import mindspore.context as context | |||
| from mindspore import Tensor, Parameter | |||
| import mindspore.nn as nn | |||
| from mindspore.common.api import _executor | |||
| from mindspore.nn import TrainOneStepCell, Momentum | |||
| from mindspore.ops import operations as P | |||
class Net(nn.Cell):
    """
    Test network: ReLU -> Transpose -> Reshape -> BatchMatMul -> BatchMatMul -> Reshape -> Transpose -> ReLU.

    The strategies `stra1` .. `stra6` are applied, in order, to the first ReLU, the first Transpose,
    the two BatchMatMul primitives, the second Transpose and the second ReLU. The two Reshape
    primitives are created without an explicit strategy.
    """

    def __init__(self, wi, wo, stra1=None, stra2=None, stra3=None, stra4=None,
                 stra5=None, stra6=None):
        super().__init__()
        # Front half: activation and layout change.
        self.relu = P.ReLU().shard(stra1)
        self.transpose = P.Transpose().shard(stra2)
        # Two batched matmuls, each with its own weight parameter.
        self.wi = Parameter(wi, "wi")
        self.batch_mm = P.BatchMatMul().shard(stra3)
        self.wo = Parameter(wo, "wo")
        self.batch_mm2 = P.BatchMatMul().shard(stra4)
        # Back half: undo the layout change and activate again.
        self.transpose2 = P.Transpose().shard(stra5)
        self.relu2 = P.ReLU().shard(stra6)
        self.reshape = P.Reshape()
        self.reshape2 = P.Reshape()

    def construct(self, x):
        activated = self.relu(x)
        trans_out = self.transpose(activated, (2, 0, 3, 1))
        # Merge axes 1 and 2 of the transposed tensor so it can be fed to BatchMatMul.
        merged_shape = (trans_out.shape[0], trans_out.shape[1]*trans_out.shape[2], trans_out.shape[3])
        merged = self.reshape(trans_out, merged_shape)
        hidden = self.batch_mm(merged, self.wi)
        projected = self.batch_mm2(hidden, self.wo)
        # Restore the 4-D transposed shape before transposing again.
        restored = self.reshape2(projected, trans_out.shape)
        transposed_back = self.transpose2(restored, (1, 3, 0, 2))
        return self.relu2(transposed_back)
# Module-level fixtures, all filled with ones and typed float32:
#   _x  - network input of shape (32, 16, 48, 128), passed to _executor.compile in compile_net
#   _wi - weight of shape (48, 16, 64), passed as `wi` to Net by the tests
#   _wo - weight of shape (48, 64, 16), passed as `wo` to Net by the tests
| _x = Tensor(np.ones([32, 16, 48, 128]), dtype=ms.float32) | |||
| _wi = Tensor(np.ones([48, 16, 64]), dtype=ms.float32) | |||
| _wo = Tensor(np.ones([48, 64, 16]), dtype=ms.float32) | |||
def compile_net(net):
    """
    Wrap `net` in a Momentum-driven TrainOneStepCell and compile it in graph mode with `_x` as input.

    The auto-parallel context is reset afterwards, even when compilation raises, so that the
    settings a test applied via `context.set_auto_parallel_context` do not leak into later tests.

    Args:
        net (nn.Cell): The network to compile.
    """
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    try:
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _executor.compile(train_net, _x)
    finally:
        # Always restore the default auto-parallel settings, including on failure.
        context.reset_auto_parallel_context()
def test_batchmm():
    """Compile Net on 8 devices with AllToAll enabled; every sharded operator splits only its first axis by 8."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, enable_alltoall=True,
                                      global_rank=0)
    # Strategy for the 4-D ReLU/Transpose operators.
    four_dim_stra = ((8, 1, 1, 1),)
    # Strategy for the two-input 3-D BatchMatMul operators.
    bmm_stra = ((8, 1, 1), (8, 1, 1))
    net = Net(_wi, _wo,
              stra1=four_dim_stra, stra2=four_dim_stra,
              stra3=bmm_stra, stra4=bmm_stra,
              stra5=four_dim_stra, stra6=four_dim_stra)
    compile_net(net)
def test_batchmm2():
    """Compile Net on 32 devices with AllToAll enabled; the BatchMatMul operators additionally split by 8."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", enable_alltoall=True,
                                      device_num=32, global_rank=0)
    # Strategy for the 4-D ReLU/Transpose operators: split the first axis by 4.
    four_dim_stra = ((4, 1, 1, 1),)
    # The first BatchMatMul splits the last axis of its weight; the second splits the contracted axis.
    first_bmm_stra = ((4, 1, 1), (4, 1, 8))
    second_bmm_stra = ((4, 1, 8), (4, 8, 1))
    net = Net(_wi, _wo,
              stra1=four_dim_stra, stra2=four_dim_stra,
              stra3=first_bmm_stra, stra4=second_bmm_stra,
              stra5=four_dim_stra, stra6=four_dim_stra)
    compile_net(net)