|
|
|
@@ -269,13 +269,41 @@ bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePa |
|
|
|
|
|
|
|
bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); } |
|
|
|
|
|
|
|
static bool IsCtrlSink() { |
|
|
|
auto ms_ctx = MsContext::GetInstance(); |
|
|
|
std::string device_target = ms_ctx->device_target(); |
|
|
|
if (device_target != kAscendDevice) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (!ms_ctx->enable_task_sink()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
char *enable_ctrl_sink = std::getenv("ENABLE_CTRL_SINK"); |
|
|
|
if (enable_ctrl_sink == nullptr) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
std::string enable_ctrl_sink_str(enable_ctrl_sink); |
|
|
|
if (enable_ctrl_sink_str == "0") { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool TaskEmitAction(const ResourcePtr &res) { |
|
|
|
if (res->func_graph() == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "TaskEmit args error"; |
|
|
|
} |
|
|
|
FuncGraphPtr func_graph = res->func_graph(); |
|
|
|
|
|
|
|
auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>(); |
|
|
|
|
|
|
|
if (IsCtrlSink()) { |
|
|
|
res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops; |
|
|
|
if (bc_ptr->name() == kMsConvert) { |
|
|
|
cut_list = compile::GetMsNonlinearOps(); |
|
|
|
@@ -286,10 +314,31 @@ bool TaskEmitAction(const ResourcePtr &res) { |
|
|
|
} |
|
|
|
|
|
|
|
bool ExecuteAction(const ResourcePtr &res) { |
|
|
|
if (res->results().count(kOutput) == 0 || !res->results()[kOutput].is<compile::FinalVMPtr>()) { |
|
|
|
if (res->results().count(kOutput) == 0) { |
|
|
|
MS_LOG(EXCEPTION) << "Execute args error"; |
|
|
|
} |
|
|
|
|
|
|
|
if (IsCtrlSink()) { |
|
|
|
if (!res->results()[kOutput].is<GraphId>()) { |
|
|
|
MS_LOG(EXCEPTION) << "Execute args error"; |
|
|
|
} |
|
|
|
|
|
|
|
auto graph_id = res->results()[kOutput].cast<GraphId>(); |
|
|
|
auto bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::MsBackend>>(); |
|
|
|
compile::VmEvalFuncPtr run = |
|
|
|
std::make_shared<compile::VmEvalFunc>([&bc_ptr, graph_id](const VectorRef &args) -> BaseRef { |
|
|
|
MS_LOG(INFO) << "Execute args size" << args.size(); |
|
|
|
auto outs = bc_ptr->RunGraph(graph_id, args); |
|
|
|
MS_LOG(DEBUG) << "out size" << outs.size(); |
|
|
|
return outs[0]; |
|
|
|
}); |
|
|
|
res->results()[kOutput] = run; |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
if (!res->results()[kOutput].is<compile::FinalVMPtr>()) { |
|
|
|
MS_LOG(EXCEPTION) << "Execute args error"; |
|
|
|
} |
|
|
|
compile::FinalVMPtr vm = res->results()[kOutput].cast<compile::FinalVMPtr>(); |
|
|
|
if (vm == nullptr) { |
|
|
|
MS_LOG(INFO) << "Call GE to Run the func_graph instead of VM"; |
|
|
|
|