| @@ -2313,6 +2313,23 @@ std::vector<uint32_t> GenerateBucketSizeList(const KernelGraphPtr &graph, const | |||||
| return bucket_size_list; | return bucket_size_list; | ||||
| } | } | ||||
| void PreProcessOnSplitIndex(const KernelGraphPtr &graph, vector<uint32_t> *split_index) { | |||||
| if (split_index->empty()) { | |||||
| return; | |||||
| } | |||||
| // calculate split index num | |||||
| auto split_index_len = split_index->size(); | |||||
| uint32_t split_index_num = (*split_index)[split_index_len - 1]; | |||||
| // obtain graph output tensor num | |||||
| auto grads_count = GetBpropGraphGradsCount(graph); | |||||
| if (split_index_num == 0 || split_index_num >= grads_count) { | |||||
| MS_LOG(EXCEPTION) << "invalid AllReduce split index " << split_index_num << " and grads count " << grads_count; | |||||
| } | |||||
| if (split_index_num < grads_count - 1) { | |||||
| split_index->push_back(grads_count - 1); | |||||
| } | |||||
| } | |||||
| void SessionBasic::InitAllBucket(const KernelGraphPtr &graph) { | void SessionBasic::InitAllBucket(const KernelGraphPtr &graph) { | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| MS_LOG(INFO) << "Init Bucket start, graph_id:" << graph->graph_id(); | MS_LOG(INFO) << "Init Bucket start, graph_id:" << graph->graph_id(); | ||||
| @@ -2325,6 +2342,7 @@ void SessionBasic::InitAllBucket(const KernelGraphPtr &graph) { | |||||
| std::vector<std::shared_ptr<device::Bucket>> bucket_list; | std::vector<std::shared_ptr<device::Bucket>> bucket_list; | ||||
| // Create bucket for every split allreduce ops | // Create bucket for every split allreduce ops | ||||
| auto split_index = GetAllReduceSplitIndex(); | auto split_index = GetAllReduceSplitIndex(); | ||||
| PreProcessOnSplitIndex(graph, &split_index); | |||||
| auto bucket_size_list = GenerateBucketSizeList(graph, split_index); | auto bucket_size_list = GenerateBucketSizeList(graph, split_index); | ||||
| uint32_t bucket_id = 0; | uint32_t bucket_id = 0; | ||||
| for (auto bucket_size : bucket_size_list) { | for (auto bucket_size : bucket_size_list) { | ||||