|
|
|
@@ -54,6 +54,7 @@ |
|
|
|
#include "pipeline/jit/resource.h" |
|
|
|
#include "pipeline/jit/pipeline.h" |
|
|
|
#include "pipeline/jit/pass.h" |
|
|
|
#include "frontend/parallel/context.h" |
|
|
|
|
|
|
|
#ifdef ENABLE_GE |
|
|
|
#include "pipeline/pynative/pynative_execute_ge.h" |
|
|
|
@@ -523,7 +524,17 @@ PynativeExecutor::~PynativeExecutor() { |
|
|
|
ClearRes(); |
|
|
|
} |
|
|
|
|
|
|
|
void CheckPyNativeContext() { |
|
|
|
auto context = parallel::ParallelContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(context); |
|
|
|
auto parallel_mode = context->parallel_mode(); |
|
|
|
if (parallel_mode != parallel::STAND_ALONE && parallel_mode != parallel::DATA_PARALLEL) { |
|
|
|
MS_LOG(EXCEPTION) << "PyNative Only support STAND_ALONE and DATA_PARALLEL, but got:" << parallel_mode; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
py::object RunOp(const py::args &args) { |
|
|
|
CheckPyNativeContext(); |
|
|
|
auto executor = PynativeExecutor::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(executor); |
|
|
|
OpExecInfoPtr op_exec_info = executor->GenerateOpExecInfo(args); |
|
|
|
|