Browse Source

!20712 [Auto parallel] Enable All2All when inferring redistribution ops

Merge pull request !20712 from Xiaoda/76-change-alltoall-redistribution
tags/v1.4.0
i-robot Gitee 4 years ago
parent
commit
c80098b261
7 changed files with 126 additions and 5 deletions
  1. +3
    -0
      mindspore/ccsrc/frontend/parallel/context.cc
  2. +4
    -0
      mindspore/ccsrc/frontend/parallel/context.h
  3. +1
    -1
      mindspore/ccsrc/frontend/parallel/step_parallel.cc
  4. +2
    -1
      mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.cc
  5. +2
    -0
      mindspore/ccsrc/pipeline/jit/init.cc
  6. +25
    -3
      mindspore/parallel/_auto_parallel_context.py
  7. +89
    -0
      tests/ut/python/parallel/test_batchmm.py

+ 3
- 0
mindspore/ccsrc/frontend/parallel/context.cc View File

@@ -72,6 +72,7 @@ void ParallelContext::Reset() {
optimizer_weight_shard_size_ = -1;
optimizer_weight_shard_aggregated_save_ = false;
sharding_propagation_ = false;
enable_all2all_ = false;
}

void ParallelContext::set_device_num(int64_t device_num) {
@@ -269,5 +270,7 @@ void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func
}

// Record whether sharding-configured operators should propagate their
// strategies to other operators (AUTO_PARALLEL mode).
void ParallelContext::set_sharding_propagation(const bool stra_pto) {
  sharding_propagation_ = stra_pto;
}

// Record whether AllToAll may be used; when false, AllGather and Split are
// used instead during redistribution.
void ParallelContext::set_enable_all2all(const bool enable) {
  enable_all2all_ = enable;
}
} // namespace parallel
} // namespace mindspore

+ 4
- 0
mindspore/ccsrc/frontend/parallel/context.h View File

@@ -126,6 +126,8 @@ class ParallelContext {
std::string communi_parallel_mode() const { return communi_parallel_mode_; }
void set_sharding_propagation(const bool);
bool sharding_propagation() const { return sharding_propagation_; }
void set_enable_all2all(const bool);
bool enable_all2all() const { return enable_all2all_; }

void Reset();
void ParallelParameterContextInitShape(const FuncGraphPtr &func_graph);
@@ -165,6 +167,8 @@ class ParallelContext {
// In AUTO_PARALLEL mode, 'sharding_propagation_' = True indicates that sharding-configured operators
// will propagate the sharding strategies to other operators with minimum redistribution cost.
bool sharding_propagation_;
// Enable AllToAll or not. If false, use AllGather and Split.
bool enable_all2all_;
};

} // namespace parallel


+ 1
- 1
mindspore/ccsrc/frontend/parallel/step_parallel.cc View File

@@ -815,7 +815,7 @@ std::vector<AnfNodePtr> ReplaceOpInput(const Operator &replace_op, const std::st
replace_input.push_back(node->input(i));
}
}
SetCommunicationOpGroupLabel(replace_input);
return replace_input;
}



+ 2
- 1
mindspore/ccsrc/frontend/parallel/tensor_layout/redistribution_operator_infer.cc View File

@@ -19,6 +19,7 @@
#include <utility>

#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/context.h"

namespace mindspore {
namespace parallel {
@@ -133,7 +134,7 @@ Status RedistributionOperatorInfer::InferPermuteByAxis() {
[out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) {
int64_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim);
int64_t dev_num = dev_mat_.GetDimByReverseIdx(LongToSize(out_dim));
if (is_cost_model_) {
if (ParallelContext::GetInstance()->enable_all2all()) {
int64_t dev_dim = in_tensor_map_.GetDimByIdx(LongToUlong(cat_dim));
Args args_alltoall = {dev_mat_.GetDimByReverseIdx(LongToUlong(dev_dim)), UlongToLong(index), cat_dim, dev_dim,
dev_num};


+ 2
- 0
mindspore/ccsrc/pipeline/jit/init.cc View File

@@ -189,6 +189,8 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_sharding_propagation", &ParallelContext::set_sharding_propagation,
"Set sharding strategy propagation value.")
.def("get_sharding_propagation", &ParallelContext::sharding_propagation, "Get sharding strategy propagation value.")
.def("set_enable_alltoall", &ParallelContext::set_enable_all2all, "Set the enabling AllToAll value.")
.def("get_enable_alltoall", &ParallelContext::enable_all2all, "Get the enabling AllToAll value.")
.def("reset", &ParallelContext::Reset, "Reset auto parallel context.");

(void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")


+ 25
- 3
mindspore/parallel/_auto_parallel_context.py View File

@@ -463,6 +463,24 @@ class _AutoParallelContext:
self.check_context_handle()
return self._context_handle.get_sharding_propagation()

def set_enable_alltoall(self, enable_a2a):
    """
    Set whether AllToAll communication is allowed. If False, AllGather and
    Split are used to circumvent AllToAll. Default: False.

    Args:
        enable_a2a (bool): Enable/disable AllToAll.

    Raises:
        TypeError: If `enable_a2a` is not a bool.
    """
    self.check_context_handle()
    if not isinstance(enable_a2a, bool):
        # Name the expected type and the actual one so the failure is actionable.
        raise TypeError("The argument 'enable_a2a' must be a bool, but got {}.".format(
            type(enable_a2a).__name__))
    self._context_handle.set_enable_alltoall(enable_a2a)

def get_enable_alltoall(self):
    """Report whether AllToAll communication is currently enabled."""
    self.check_context_handle()
    enabled = self._context_handle.get_enable_alltoall()
    return enabled

def set_communi_parallel_mode(self, communi_parallel_mode):
"""
Set communication parallel mode.
@@ -584,7 +602,8 @@ _set_auto_parallel_context_func_map = {
"communi_parallel_mode": auto_parallel_context().set_communi_parallel_mode,
"optimizer_weight_shard_size": auto_parallel_context().set_optimizer_weight_shard_size,
"optimizer_weight_shard_aggregated_save": auto_parallel_context().set_optimizer_weight_shard_aggregated_save,
"sharding_propagation": auto_parallel_context().set_sharding_propagation}
"sharding_propagation": auto_parallel_context().set_sharding_propagation,
"enable_alltoall": auto_parallel_context().set_enable_alltoall}


_get_auto_parallel_context_func_map = {
@@ -606,7 +625,8 @@ _get_auto_parallel_context_func_map = {
"communi_parallel_mode": auto_parallel_context().get_communi_parallel_mode,
"optimizer_weight_shard_size": auto_parallel_context().get_optimizer_weight_shard_size,
"optimizer_weight_shard_aggregated_save": auto_parallel_context().get_optimizer_weight_shard_aggregated_save,
"sharding_propagation": auto_parallel_context().get_sharding_propagation}
"sharding_propagation": auto_parallel_context().get_sharding_propagation,
"enable_alltoall": auto_parallel_context().get_enable_alltoall}


@args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool,
@@ -616,7 +636,7 @@ _get_auto_parallel_context_func_map = {
grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str,
communi_parallel_mode=str, optimizer_weight_shard_size=int,
optimizer_weight_shard_aggregated_save=bool,
sharding_propagation=bool)
sharding_propagation=bool, enable_alltoall=bool)

def _set_auto_parallel_context(**kwargs):
"""
@@ -682,6 +702,8 @@ def _set_auto_parallel_context(**kwargs):
the strategy-configured operators will propagate the strategies to other
operators with minimum redistribution cost; otherwise, the algorithm will
search the desired strategies. Default: False.
enable_alltoall (bool): Set the value of enabling AllToAll. If False, AllGather and Split are used to
circumvent AllToAll. Default: False.

Raises:
ValueError: If input key is not attribute in auto parallel context.


+ 89
- 0
tests/ut/python/parallel/test_batchmm.py View File

@@ -0,0 +1,89 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import mindspore as ms
import mindspore.context as context
from mindspore import Tensor, Parameter
import mindspore.nn as nn
from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, Momentum
from mindspore.ops import operations as P

class Net(nn.Cell):
    """Batched-matmul test network with per-operator shard strategies.

    Pipeline: ReLU -> Transpose -> Reshape (merge middle dims) ->
    BatchMatMul(wi) -> BatchMatMul(wo) -> Reshape (restore) ->
    Transpose -> ReLU.
    """

    def __init__(self, wi, wo, stra1=None, stra2=None, stra3=None, stra4=None,
                 stra5=None, stra6=None):
        super(Net, self).__init__()
        # Weights of the two batched matmuls.
        self.wi = Parameter(wi, "wi")
        self.wo = Parameter(wo, "wo")
        # One operator per pipeline stage, each with its own strategy.
        self.relu = P.ReLU().shard(stra1)
        self.transpose = P.Transpose().shard(stra2)
        self.batch_mm = P.BatchMatMul().shard(stra3)
        self.batch_mm2 = P.BatchMatMul().shard(stra4)
        self.transpose2 = P.Transpose().shard(stra5)
        self.relu2 = P.ReLU().shard(stra6)
        self.reshape = P.Reshape()
        self.reshape2 = P.Reshape()

    def construct(self, x):
        y = self.relu(x)
        permuted = self.transpose(y, (2, 0, 3, 1))
        # Fold the two middle axes into one for the batched matmuls.
        merged = (permuted.shape[0],
                  permuted.shape[1] * permuted.shape[2],
                  permuted.shape[3])
        y = self.reshape(permuted, merged)
        y = self.batch_mm(y, self.wi)
        y = self.batch_mm2(y, self.wo)
        # Restore the pre-merge shape, then undo the permutation.
        y = self.reshape2(y, permuted.shape)
        y = self.transpose2(y, (1, 3, 0, 2))
        y = self.relu2(y)
        return y

# Shared fixtures: the network input and the two batch-matmul weights.
_x = Tensor(np.ones((32, 16, 48, 128)), dtype=ms.float32)
_wi = Tensor(np.ones((48, 16, 64)), dtype=ms.float32)
_wo = Tensor(np.ones((48, 64, 16)), dtype=ms.float32)


def compile_net(net):
    """Compile `net` in graph mode inside a Momentum/TrainOneStepCell wrapper,
    then reset the auto-parallel context."""
    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
    opt = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    wrapped = TrainOneStepCell(net, opt)
    wrapped.set_auto_parallel()
    wrapped.set_train()
    _executor.compile(wrapped, _x)
    context.reset_auto_parallel_context()

def test_batchmm():
    """Compile on 8 devices with AllToAll enabled and uniform strategies."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, enable_alltoall=True,
                                      global_rank=0)
    # All four-dim ops share one strategy; both matmuls share another.
    act_stra = ((8, 1, 1, 1),)
    mm_stra = ((8, 1, 1), (8, 1, 1))
    net = Net(_wi, _wo, stra1=act_stra, stra2=act_stra, stra3=mm_stra,
              stra4=mm_stra, stra5=act_stra, stra6=act_stra)
    compile_net(net)

def test_batchmm2():
    """Compile on 32 devices; the matmul strategies force redistribution."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", enable_alltoall=True,
                                      device_num=32, global_rank=0)
    act_stra = ((4, 1, 1, 1),)
    net = Net(_wi, _wo,
              stra1=act_stra,
              stra2=act_stra,
              stra3=((4, 1, 1), (4, 1, 8)),
              stra4=((4, 1, 8), (4, 8, 1)),
              stra5=act_stra,
              stra6=act_stra)
    compile_net(net)

Loading…
Cancel
Save