From 91b7c4a9b6e713e1144814b43cb00ee096ebc090 Mon Sep 17 00:00:00 2001 From: yangzhenzhang Date: Tue, 27 Apr 2021 21:03:18 +0800 Subject: [PATCH] fix gather bug --- .../parallel/ops_info/gather_v2_p_info.cc | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc index ec56ce1071..c8abaf00eb 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc @@ -313,6 +313,17 @@ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) { return FAILED; } + if (ShardBatchAndAxis(strategy->GetInputDim())) { + shard_batch_and_axis_ = true; + axis_split_forward_allreduce_ = true; + MS_LOG(INFO) << name_ << ": Sharding batch and axis, and the forward use allreduce"; + return SUCCESS; + } else if (is_auto_parallel_) { + // in auto parallel mode, this function will be called many times, so need to reset the flags + shard_batch_and_axis_ = false; + axis_split_forward_allreduce_ = false; + } + // axis=0, index_shape(0)%param_strategy(0) must be 0 Shape index_shape = inputs_shape_.at(1); if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0) && !dynamic_shape_indices_) { @@ -331,17 +342,6 @@ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) { return SUCCESS; } - if (ShardBatchAndAxis(strategy->GetInputDim())) { - shard_batch_and_axis_ = true; - axis_split_forward_allreduce_ = true; - MS_LOG(INFO) << name_ << ": Sharding batch and axis, and the forward use allreduce"; - return SUCCESS; - } else if (is_auto_parallel_) { - // in auto parallel mode, this function will be called many times, so need to reset the flags - shard_batch_and_axis_ = false; - axis_split_forward_allreduce_ = false; - } - // axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0 if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(LongToSize(axis_))) != 0) { MS_LOG(DEBUG) << name_ << ": param_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";