Browse Source

fix a split index bug in launch allreduce

pull/14686/head
lvchangquan 4 years ago
parent
commit
7fa0e12223
1 changed files with 18 additions and 0 deletions
  1. +18
    -0
      mindspore/ccsrc/backend/session/session_basic.cc

+ 18
- 0
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -2313,6 +2313,23 @@ std::vector<uint32_t> GenerateBucketSizeList(const KernelGraphPtr &graph, const
return bucket_size_list;
}

void PreProcessOnSplitIndex(const KernelGraphPtr &graph, vector<uint32_t> *split_index) {
if (split_index->empty()) {
return;
}
// calculate split index num
auto split_index_len = split_index->size();
uint32_t split_index_num = (*split_index)[split_index_len - 1];
// obtain graph output tensor num
auto grads_count = GetBpropGraphGradsCount(graph);
if (split_index_num == 0 || split_index_num >= grads_count) {
MS_LOG(EXCEPTION) << "invalid AllReduce split index " << split_index_num << " and grads count " << grads_count;
}
if (split_index_num < grads_count - 1) {
split_index->push_back(grads_count - 1);
}
}

void SessionBasic::InitAllBucket(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Init Bucket start, graph_id:" << graph->graph_id();
@@ -2325,6 +2342,7 @@ void SessionBasic::InitAllBucket(const KernelGraphPtr &graph) {
std::vector<std::shared_ptr<device::Bucket>> bucket_list;
// Create bucket for every split allreduce ops
auto split_index = GetAllReduceSplitIndex();
PreProcessOnSplitIndex(graph, &split_index);
auto bucket_size_list = GenerateBucketSizeList(graph, split_index);
uint32_t bucket_id = 0;
for (auto bucket_size : bucket_size_list) {


Loading…
Cancel
Save