|
|
|
@@ -282,7 +282,7 @@ bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPa |
|
|
|
|
|
|
|
bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kPynativePasses); } |
|
|
|
|
|
|
|
static bool IsCtrlSink(const FuncGraphPtr &graph) { |
|
|
|
static bool IsCtrlSink() { |
|
|
|
auto ms_ctx = MsContext::GetInstance(); |
|
|
|
if (ms_ctx->execution_mode() != kGraphMode) { |
|
|
|
return false; |
|
|
|
@@ -297,10 +297,9 @@ static bool IsCtrlSink(const FuncGraphPtr &graph) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (graph != nullptr && CompileGraphs::ContainMixedTarget(graph)) { |
|
|
|
if (!ms_ctx->is_multi_graph_sink()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -310,27 +309,29 @@ bool TaskEmitAction(const ResourcePtr &res) { |
|
|
|
} |
|
|
|
FuncGraphPtr func_graph = res->func_graph(); |
|
|
|
auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>(); |
|
|
|
if (IsCtrlSink(func_graph)) { |
|
|
|
res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); |
|
|
|
return true; |
|
|
|
} |
|
|
|
std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops; |
|
|
|
if (bc_ptr->name() == kMsConvert) { |
|
|
|
cut_list = compile::GetMsNonlinearOps(); |
|
|
|
} |
|
|
|
|
|
|
|
std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list); |
|
|
|
auto context_ptr = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(context_ptr); |
|
|
|
if (CompileGraphs::ContainMixedTarget(func_graph)) { |
|
|
|
bc_ptr->set_is_multi_graph_sink(false); |
|
|
|
context_ptr->set_is_multi_graph_sink(false); |
|
|
|
context_ptr->set_loop_sink_flag(false); |
|
|
|
} else if (context_ptr->execution_mode() != kPynativeMode) { |
|
|
|
std::string device_target = context_ptr->device_target(); |
|
|
|
if (device_target == kAscendDevice) { |
|
|
|
bc_ptr->set_is_multi_graph_sink(true); |
|
|
|
context_ptr->set_is_multi_graph_sink(true); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (IsCtrlSink()) { |
|
|
|
res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); |
|
|
|
return true; |
|
|
|
} |
|
|
|
std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops; |
|
|
|
if (bc_ptr->name() == kMsConvert) { |
|
|
|
cut_list = compile::GetMsNonlinearOps(); |
|
|
|
} |
|
|
|
std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list); |
|
|
|
res->results()[kOutput] = compile->CompileAndLink(func_graph); |
|
|
|
return true; |
|
|
|
} |
|
|
|
@@ -340,11 +341,10 @@ bool ExecuteAction(const ResourcePtr &res) { |
|
|
|
MS_LOG(EXCEPTION) << "Execute args error"; |
|
|
|
} |
|
|
|
|
|
|
|
if (IsCtrlSink(nullptr)) { |
|
|
|
if (IsCtrlSink()) { |
|
|
|
if (!res->results()[kOutput].is<GraphId>()) { |
|
|
|
MS_LOG(EXCEPTION) << "Execute args error"; |
|
|
|
} |
|
|
|
|
|
|
|
auto graph_id = res->results()[kOutput].cast<GraphId>(); |
|
|
|
std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>(); |
|
|
|
std::shared_ptr<compile::MsBackend> msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr); |
|
|
|
|