Browse Source

!15792 fix gather bug

From: @yangzhenzhang
Reviewed-by: @kisnwang,@stsuteng
Signed-off-by: @stsuteng
pull/15792/MERGE
mindspore-ci-bot Gitee 4 years ago
parent
commit
117beb7957
1 changed files with 11 additions and 11 deletions
  1. +11
    -11
      mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc

+ 11
- 11
mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc View File

@@ -313,6 +313,17 @@ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED;
}

if (ShardBatchAndAxis(strategy->GetInputDim())) {
shard_batch_and_axis_ = true;
axis_split_forward_allreduce_ = true;
MS_LOG(INFO) << name_ << ": Sharding batch and axis, and the forward use allreduce";
return SUCCESS;
} else if (is_auto_parallel_) {
// in auto parallel mode, this function will be called many times, so need to reset the flags
shard_batch_and_axis_ = false;
axis_split_forward_allreduce_ = false;
}

// axis=0, index_shape(0)%param_strategy(0) must be 0
Shape index_shape = inputs_shape_.at(1);
if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0) && !dynamic_shape_indices_) {
@@ -331,17 +342,6 @@ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) {
return SUCCESS;
}

if (ShardBatchAndAxis(strategy->GetInputDim())) {
shard_batch_and_axis_ = true;
axis_split_forward_allreduce_ = true;
MS_LOG(INFO) << name_ << ": Sharding batch and axis, and the forward use allreduce";
return SUCCESS;
} else if (is_auto_parallel_) {
// in auto parallel mode, this function will be called many times, so need to reset the flags
shard_batch_and_axis_ = false;
axis_split_forward_allreduce_ = false;
}

// axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0
if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(LongToSize(axis_))) != 0) {
MS_LOG(DEBUG) << name_ << ": param_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis)).";


Loading…
Cancel
Save