|
|
|
@@ -438,12 +438,9 @@ std::vector<int32_t> GetRankFromGroup(const Group &group) { |
|
|
|
|
|
|
|
Status GatherV2PInfo::InferForwardCommunication() { |
|
|
|
forward_op_.clear(); |
|
|
|
if (target_ != CPU) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
auto param_strategy = strategy_->GetInputDim().at(0); |
|
|
|
// don't split axis, no need forward communication |
|
|
|
if (param_strategy.at(IntToSize(axis_)) == 1) { |
|
|
|
// don't split axis or target is not CPU, no need forward communication |
|
|
|
if (target_ != CPU || param_strategy.at(IntToSize(axis_)) == 1) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
// split axis |
|
|
|
|