| @@ -282,7 +282,7 @@ bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPa | |||||
| bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kPynativePasses); } | bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kPynativePasses); } | ||||
| static bool IsCtrlSink(const FuncGraphPtr &graph) { | |||||
| static bool IsCtrlSink() { | |||||
| auto ms_ctx = MsContext::GetInstance(); | auto ms_ctx = MsContext::GetInstance(); | ||||
| if (ms_ctx->execution_mode() != kGraphMode) { | if (ms_ctx->execution_mode() != kGraphMode) { | ||||
| return false; | return false; | ||||
| @@ -297,10 +297,9 @@ static bool IsCtrlSink(const FuncGraphPtr &graph) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| if (graph != nullptr && CompileGraphs::ContainMixedTarget(graph)) { | |||||
| if (!ms_ctx->is_multi_graph_sink()) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -310,27 +309,29 @@ bool TaskEmitAction(const ResourcePtr &res) { | |||||
| } | } | ||||
| FuncGraphPtr func_graph = res->func_graph(); | FuncGraphPtr func_graph = res->func_graph(); | ||||
| auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>(); | auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>(); | ||||
| if (IsCtrlSink(func_graph)) { | |||||
| res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); | |||||
| return true; | |||||
| } | |||||
| std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops; | |||||
| if (bc_ptr->name() == kMsConvert) { | |||||
| cut_list = compile::GetMsNonlinearOps(); | |||||
| } | |||||
| std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list); | |||||
| auto context_ptr = MsContext::GetInstance(); | auto context_ptr = MsContext::GetInstance(); | ||||
| MS_EXCEPTION_IF_NULL(context_ptr); | MS_EXCEPTION_IF_NULL(context_ptr); | ||||
| if (CompileGraphs::ContainMixedTarget(func_graph)) { | if (CompileGraphs::ContainMixedTarget(func_graph)) { | ||||
| bc_ptr->set_is_multi_graph_sink(false); | bc_ptr->set_is_multi_graph_sink(false); | ||||
| context_ptr->set_is_multi_graph_sink(false); | |||||
| context_ptr->set_loop_sink_flag(false); | context_ptr->set_loop_sink_flag(false); | ||||
| } else if (context_ptr->execution_mode() != kPynativeMode) { | } else if (context_ptr->execution_mode() != kPynativeMode) { | ||||
| std::string device_target = context_ptr->device_target(); | std::string device_target = context_ptr->device_target(); | ||||
| if (device_target == kAscendDevice) { | if (device_target == kAscendDevice) { | ||||
| bc_ptr->set_is_multi_graph_sink(true); | bc_ptr->set_is_multi_graph_sink(true); | ||||
| context_ptr->set_is_multi_graph_sink(true); | |||||
| } | } | ||||
| } | } | ||||
| if (IsCtrlSink()) { | |||||
| res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); | |||||
| return true; | |||||
| } | |||||
| std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops; | |||||
| if (bc_ptr->name() == kMsConvert) { | |||||
| cut_list = compile::GetMsNonlinearOps(); | |||||
| } | |||||
| std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list); | |||||
| res->results()[kOutput] = compile->CompileAndLink(func_graph); | res->results()[kOutput] = compile->CompileAndLink(func_graph); | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -340,11 +341,10 @@ bool ExecuteAction(const ResourcePtr &res) { | |||||
| MS_LOG(EXCEPTION) << "Execute args error"; | MS_LOG(EXCEPTION) << "Execute args error"; | ||||
| } | } | ||||
| if (IsCtrlSink(nullptr)) { | |||||
| if (IsCtrlSink()) { | |||||
| if (!res->results()[kOutput].is<GraphId>()) { | if (!res->results()[kOutput].is<GraphId>()) { | ||||
| MS_LOG(EXCEPTION) << "Execute args error"; | MS_LOG(EXCEPTION) << "Execute args error"; | ||||
| } | } | ||||
| auto graph_id = res->results()[kOutput].cast<GraphId>(); | auto graph_id = res->results()[kOutput].cast<GraphId>(); | ||||
| std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>(); | std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>(); | ||||
| std::shared_ptr<compile::MsBackend> msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr); | std::shared_ptr<compile::MsBackend> msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr); | ||||