|
|
|
@@ -2313,6 +2313,23 @@ std::vector<uint32_t> GenerateBucketSizeList(const KernelGraphPtr &graph, const |
|
|
|
return bucket_size_list; |
|
|
|
} |
|
|
|
|
|
|
|
void PreProcessOnSplitIndex(const KernelGraphPtr &graph, vector<uint32_t> *split_index) { |
|
|
|
if (split_index->empty()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
// calculate split index num |
|
|
|
auto split_index_len = split_index->size(); |
|
|
|
uint32_t split_index_num = (*split_index)[split_index_len - 1]; |
|
|
|
// obtain graph output tensor num |
|
|
|
auto grads_count = GetBpropGraphGradsCount(graph); |
|
|
|
if (split_index_num == 0 || split_index_num >= grads_count) { |
|
|
|
MS_LOG(EXCEPTION) << "invalid AllReduce split index " << split_index_num << " and grads count " << grads_count; |
|
|
|
} |
|
|
|
if (split_index_num < grads_count - 1) { |
|
|
|
split_index->push_back(grads_count - 1); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void SessionBasic::InitAllBucket(const KernelGraphPtr &graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_LOG(INFO) << "Init Bucket start, graph_id:" << graph->graph_id(); |
|
|
|
@@ -2325,6 +2342,7 @@ void SessionBasic::InitAllBucket(const KernelGraphPtr &graph) { |
|
|
|
std::vector<std::shared_ptr<device::Bucket>> bucket_list; |
|
|
|
// Create bucket for every split allreduce ops |
|
|
|
auto split_index = GetAllReduceSplitIndex(); |
|
|
|
PreProcessOnSplitIndex(graph, &split_index); |
|
|
|
auto bucket_size_list = GenerateBucketSizeList(graph, split_index); |
|
|
|
uint32_t bucket_id = 0; |
|
|
|
for (auto bucket_size : bucket_size_list) { |
|
|
|
|