diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc
index 8c2acc2a43..41c243d03d 100644
--- a/mindspore/ccsrc/backend/session/session_basic.cc
+++ b/mindspore/ccsrc/backend/session/session_basic.cc
@@ -2313,6 +2313,29 @@ std::vector GenerateBucketSizeList(const KernelGraphPtr &graph, const
   return bucket_size_list;
 }
 
+// Checks the last entry of `split_index` against the value returned by GetBpropGraphGradsCount(graph) and,
+// when that entry is smaller than `grads_count - 1`, appends `grads_count - 1` as an additional split index.
+// An empty `split_index` is left untouched. Logs at EXCEPTION level when the last entry is 0 or is not
+// smaller than the grads count. Only the last entry is inspected; earlier entries are not validated here.
+void PreProcessOnSplitIndex(const KernelGraphPtr &graph, vector *split_index) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(split_index);
+  if (split_index->empty()) {
+    return;
+  }
+  // the last configured split index
+  const uint32_t split_index_num = split_index->back();
+  // obtain graph output tensor num
+  const auto grads_count = GetBpropGraphGradsCount(graph);
+  if (split_index_num == 0 || split_index_num >= grads_count) {
+    MS_LOG(EXCEPTION) << "invalid AllReduce split index " << split_index_num << " and grads count " << grads_count;
+  }
+  // Assuming MS_LOG(EXCEPTION) does not return, 0 < split_index_num < grads_count holds from here on,
+  // so `grads_count - 1` cannot wrap around.
+  if (split_index_num < grads_count - 1) {
+    split_index->push_back(grads_count - 1);
+  }
+}
+
 void SessionBasic::InitAllBucket(const KernelGraphPtr &graph) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_LOG(INFO) << "Init Bucket start, graph_id:" << graph->graph_id();
@@ -2325,6 +2348,7 @@ void SessionBasic::InitAllBucket(const KernelGraphPtr &graph) {
   std::vector> bucket_list;
   // Create bucket for every split allreduce ops
   auto split_index = GetAllReduceSplitIndex();
+  PreProcessOnSplitIndex(graph, &split_index);
   auto bucket_size_list = GenerateBucketSizeList(graph, split_index);
   uint32_t bucket_id = 0;
   for (auto bucket_size : bucket_size_list) {