| @@ -282,7 +282,7 @@ bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPa | |||
| bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kPynativePasses); } | |||
| static bool IsCtrlSink(const FuncGraphPtr &graph) { | |||
| static bool IsCtrlSink() { | |||
| auto ms_ctx = MsContext::GetInstance(); | |||
| if (ms_ctx->execution_mode() != kGraphMode) { | |||
| return false; | |||
| @@ -297,10 +297,9 @@ static bool IsCtrlSink(const FuncGraphPtr &graph) { | |||
| return false; | |||
| } | |||
| if (graph != nullptr && CompileGraphs::ContainMixedTarget(graph)) { | |||
| if (!ms_ctx->is_multi_graph_sink()) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| @@ -310,27 +309,29 @@ bool TaskEmitAction(const ResourcePtr &res) { | |||
| } | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>(); | |||
| if (IsCtrlSink(func_graph)) { | |||
| res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); | |||
| return true; | |||
| } | |||
| std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops; | |||
| if (bc_ptr->name() == kMsConvert) { | |||
| cut_list = compile::GetMsNonlinearOps(); | |||
| } | |||
| std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list); | |||
| auto context_ptr = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(context_ptr); | |||
| if (CompileGraphs::ContainMixedTarget(func_graph)) { | |||
| bc_ptr->set_is_multi_graph_sink(false); | |||
| context_ptr->set_is_multi_graph_sink(false); | |||
| context_ptr->set_loop_sink_flag(false); | |||
| } else if (context_ptr->execution_mode() != kPynativeMode) { | |||
| std::string device_target = context_ptr->device_target(); | |||
| if (device_target == kAscendDevice) { | |||
| bc_ptr->set_is_multi_graph_sink(true); | |||
| context_ptr->set_is_multi_graph_sink(true); | |||
| } | |||
| } | |||
| if (IsCtrlSink()) { | |||
| res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); | |||
| return true; | |||
| } | |||
| std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops; | |||
| if (bc_ptr->name() == kMsConvert) { | |||
| cut_list = compile::GetMsNonlinearOps(); | |||
| } | |||
| std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list); | |||
| res->results()[kOutput] = compile->CompileAndLink(func_graph); | |||
| return true; | |||
| } | |||
| @@ -340,11 +341,10 @@ bool ExecuteAction(const ResourcePtr &res) { | |||
| MS_LOG(EXCEPTION) << "Execute args error"; | |||
| } | |||
| if (IsCtrlSink(nullptr)) { | |||
| if (IsCtrlSink()) { | |||
| if (!res->results()[kOutput].is<GraphId>()) { | |||
| MS_LOG(EXCEPTION) << "Execute args error"; | |||
| } | |||
| auto graph_id = res->results()[kOutput].cast<GraphId>(); | |||
| std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>(); | |||
| std::shared_ptr<compile::MsBackend> msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr); | |||