Browse Source

!14855 fix a allreduce bug in pynative mode

From: @lvchangquan
Reviewed-by: @zhoufeng54,@jjfeing
Signed-off-by: @jjfeing
tags/v1.3.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
109c29c546
1 changed files with 10 additions and 0 deletions
  1. +10
    -0
      mindspore/ccsrc/backend/session/session_basic.cc

+ 10
- 0
mindspore/ccsrc/backend/session/session_basic.cc View File

@@ -2294,6 +2294,7 @@ std::vector<uint32_t> GenerateBucketSizeList(const KernelGraphPtr &graph, const
}

void PreProcessOnSplitIndex(const KernelGraphPtr &graph, vector<uint32_t> *split_index) {
MS_EXCEPTION_IF_NULL(split_index);
if (split_index->empty()) {
return;
}
@@ -2313,6 +2314,15 @@ void PreProcessOnSplitIndex(const KernelGraphPtr &graph, vector<uint32_t> *split
void SessionBasic::InitAllBucket(const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Init Bucket start, graph_id:" << graph->graph_id();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
auto parallel_context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context);
auto parallel_mode = parallel_context->parallel_mode();
if (!pynative_mode || parallel_mode != parallel::DATA_PARALLEL) {
return;
}
SetGraphBpropAttr(graph);

if (!graph->is_bprop()) {


Loading…
Cancel
Save