|
|
|
@@ -2237,10 +2237,12 @@ void PreProcessOnSplitIndex(const KernelGraphPtr &graph, vector<uint32_t> *split |
|
|
|
uint32_t split_index_num = (*split_index)[split_index_len - 1]; |
|
|
|
// obtain graph output tensor num |
|
|
|
auto grads_count = GetBpropGraphGradsCount(graph); |
|
|
|
if (split_index_num == 0 || split_index_num >= grads_count) { |
|
|
|
MS_LOG(EXCEPTION) << "invalid AllReduce split index " << split_index_num << " and grads count " << grads_count; |
|
|
|
} |
|
|
|
if (split_index_num < grads_count - 1) { |
|
|
|
if (split_index_num >= grads_count) { |
|
|
|
MS_LOG(WARNING) << "Invalid all_reduce_fusion_config:" << *split_index << " total grads count:" << grads_count |
|
|
|
<< ". AllReduces are fused into one AllReduce."; |
|
|
|
split_index->clear(); |
|
|
|
split_index->push_back(grads_count); |
|
|
|
} else if (split_index_num < grads_count - 1) { |
|
|
|
split_index->push_back(grads_count - 1); |
|
|
|
} |
|
|
|
} |
|
|
|
|