Browse Source

Add PyNative parallel context checker

tags/v1.2.0-rc1
caifubi 5 years ago
parent
commit
7366fcf761
1 changed files with 11 additions and 0 deletions
  1. +11
    -0
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc

+ 11
- 0
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -54,6 +54,7 @@
#include "pipeline/jit/resource.h" #include "pipeline/jit/resource.h"
#include "pipeline/jit/pipeline.h" #include "pipeline/jit/pipeline.h"
#include "pipeline/jit/pass.h" #include "pipeline/jit/pass.h"
#include "frontend/parallel/context.h"


#ifdef ENABLE_GE #ifdef ENABLE_GE
#include "pipeline/pynative/pynative_execute_ge.h" #include "pipeline/pynative/pynative_execute_ge.h"
@@ -523,7 +524,17 @@ PynativeExecutor::~PynativeExecutor() {
ClearRes(); ClearRes();
} }


void CheckPyNativeContext() {
auto context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
auto parallel_mode = context->parallel_mode();
if (parallel_mode != parallel::STAND_ALONE && parallel_mode != parallel::DATA_PARALLEL) {
MS_LOG(EXCEPTION) << "PyNative Only support STAND_ALONE and DATA_PARALLEL, but got:" << parallel_mode;
}
}

py::object RunOp(const py::args &args) { py::object RunOp(const py::args &args) {
CheckPyNativeContext();
auto executor = PynativeExecutor::GetInstance(); auto executor = PynativeExecutor::GetInstance();
MS_EXCEPTION_IF_NULL(executor); MS_EXCEPTION_IF_NULL(executor);
OpExecInfoPtr op_exec_info = executor->GenerateOpExecInfo(args); OpExecInfoPtr op_exec_info = executor->GenerateOpExecInfo(args);


Loading…
Cancel
Save