From 7b33f3e2ac544c4f53acd7dbfb158b1b5ee1cd3d Mon Sep 17 00:00:00 2001 From: yangzhenzhang Date: Mon, 30 Nov 2020 15:53:32 +0800 Subject: [PATCH] gatherv2 axis split repeated calculation --- .../parallel/ops_info/gather_v2_p_info.cc | 27 +++++++++++++++---- tests/ut/python/parallel/test_gather_v2.py | 13 +++++++++ 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc index e50e6c659b..bddf06b96c 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc @@ -208,7 +208,7 @@ Status GatherV2PInfo::CheckManualSplit(const Strategys &strategy) { int64_t split_shape_sum = std::accumulate(param_split_shapes_.begin(), param_split_shapes_.end(), 0, [](int64_t s, int64_t shape) { return s + shape; }); if (split_shape_sum != inputs_shape_[0][0]) { - MS_LOG(ERROR) << name_ << ": Sum of splited shapes must be equal to param_shape[0]"; + MS_LOG(ERROR) << name_ << ": Sum of split shapes must be equal to param_shape[0]"; return FAILED; } return SUCCESS; @@ -261,21 +261,34 @@ Status GatherV2PInfo::CheckStrategy(const StrategyPtr &strategy) { return FAILED; } - // param_strategy(axis) != 1, index can't be splited + // param_strategy(axis) != 1, index can't be split auto index_strategy = strategy->GetInputDim().at(1); auto product_i = std::accumulate(index_strategy.begin(), index_strategy.end(), 1, std::multiplies()); if ((param_strategy.at(LongToSize(axis_)) != 1) && (product_i != 1)) { - MS_LOG(DEBUG) << name_ << ": param is splited at dim (axis)" << axis_ << " ,index can't be splited."; + MS_LOG(DEBUG) << name_ << ": param is split at dim (axis)" << axis_ << " ,index can't be split."; return FAILED; } - // param_strategy(axis) != 1, Don't support repeated calc + // param_strategy(axis) != 1, and axis != 0, don't support repeated calc auto product_p = std::accumulate(param_strategy.begin(), param_strategy.end(), 1, std::multiplies()); - if (product_p != stage_device_size_ && param_strategy.at(IntToSize(axis_)) != 1) { + if ((product_p != stage_device_size_) && (param_strategy.at(IntToSize(axis_)) != 1) && (axis_ != 0)) { MS_LOG(DEBUG) << name_ << ": Invalid strategy. Don't support repeated calc."; return FAILED; } + // param_strategy(axis) != 1, and axis == 0, and repeated calculation, need to set repeated num to the right + // of dev-matrix. For example, parameter strategy is [2, 1], indices strategy is [1, 1], dev num is 16, + // and dev_matrix is [2, 1, 1, 1, 8], the communication groups are [0, 8] and [0, 1, 2, 3, 4, 5, 6, 7], they + // can communicate normally. + if ((product_p != stage_device_size_) && (param_strategy.at(IntToSize(axis_)) != 1) && (axis_ == 0)) { + if ((param_strategy.size() == 2) && (param_strategy[1] != 1)) { + MS_LOG(DEBUG) << name_ << ": axis(0) is split, and param_strategy[1] != 1, don't support repeated calc."; + return FAILED; + } + MS_LOG(INFO) << name_ << ": split axis(0) and repeat calculation"; + repeated_num_in_dev_matrix_right_ = true; + } + return SUCCESS; } @@ -493,6 +506,10 @@ Status GatherV2PInfo::InferBias() { // params_size=2, axis=0 if ((input_shape.size() == 2) && (axis_ == 0)) { slice_size_ = input_shape.at(0) / params_strategy.at(0); + // if repeated calculation, because the repeated num in the right of dev-matrix, so rank need to div repeated num + if (repeated_calc_num_ > 1) { + rank = rank / repeated_calc_num_; + } bias_ = rank / params_strategy.at(1) * slice_size_; return SUCCESS; } diff --git a/tests/ut/python/parallel/test_gather_v2.py b/tests/ut/python/parallel/test_gather_v2.py index 1509036509..b24bcd6f35 100644 --- a/tests/ut/python/parallel/test_gather_v2.py +++ b/tests/ut/python/parallel/test_gather_v2.py @@ -190,6 +190,19 @@ def test_gatherv2_forward_all_reduce(): _executor.compile(net, x, y) +def test_gatherv2_split_axis_0_repeat_calc(): + context.set_auto_parallel_context(device_num=8, global_rank=7, parallel_mode="semi_auto_parallel") + strategy1 = ((4, 1), (1, 1)) + strategy2 = ((2, 4, 1), (2, 4, 1)) + net = GradWrap(NetWithLoss(Net(0, strategy1, strategy2, shape=[2, 64]))) + net.set_auto_parallel() + + x = Tensor(np.ones([64, 64]), dtype=ms.float32) + y = Tensor(np.ones([2, 64, 64]), dtype=ms.float32) + net.set_train() + _executor.compile(net, x, y) + + def test_gatherv2_auto0(): context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel") net = GradWrap(NetWithLoss(Net(0)))