|
|
|
@@ -313,6 +313,17 @@ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (ShardBatchAndAxis(strategy->GetInputDim())) { |
|
|
|
shard_batch_and_axis_ = true; |
|
|
|
axis_split_forward_allreduce_ = true; |
|
|
|
MS_LOG(INFO) << name_ << ": Sharding batch and axis, and the forward use allreduce"; |
|
|
|
return SUCCESS; |
|
|
|
} else if (is_auto_parallel_) { |
|
|
|
// in auto parallel mode, this function will be called many times, so need to reset the flags |
|
|
|
shard_batch_and_axis_ = false; |
|
|
|
axis_split_forward_allreduce_ = false; |
|
|
|
} |
|
|
|
|
|
|
|
// axis=0, index_shape(0)%param_strategy(0) must be 0 |
|
|
|
Shape index_shape = inputs_shape_.at(1); |
|
|
|
if ((axis_ == 0) && (index_shape.at(0) % param_strategy.at(0) != 0) && !dynamic_shape_indices_) { |
|
|
|
@@ -331,17 +342,6 @@ Status GatherPInfo::CheckStrategy(const StrategyPtr &strategy) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
if (ShardBatchAndAxis(strategy->GetInputDim())) { |
|
|
|
shard_batch_and_axis_ = true; |
|
|
|
axis_split_forward_allreduce_ = true; |
|
|
|
MS_LOG(INFO) << name_ << ": Sharding batch and axis, and the forward use allreduce"; |
|
|
|
return SUCCESS; |
|
|
|
} else if (is_auto_parallel_) { |
|
|
|
// in auto parallel mode, this function will be called many times, so need to reset the flags |
|
|
|
shard_batch_and_axis_ = false; |
|
|
|
axis_split_forward_allreduce_ = false; |
|
|
|
} |
|
|
|
|
|
|
|
// axis != 0, param_shape(0)%(param_strategy(0)*param_strategy(axis)) must be 0 |
|
|
|
if (axis_ != 0 && param_shape.at(0) % (param_strategy.at(0) * param_strategy.at(LongToSize(axis_))) != 0) { |
|
|
|
MS_LOG(DEBUG) << name_ << ": param_shape(0) can't be divided by (param_strategy(0)*param_strategy(axis))."; |
|
|
|
|