|
|
|
@@ -181,13 +181,6 @@ FuncGraphPtr ExecutorPy::GetFuncGraph(const std::string &phase) { |
|
|
|
return info_[phase]->func_graph; |
|
|
|
} |
|
|
|
|
|
|
|
std::size_t ExecutorPy::ArgListSize(const std::string &phase) { |
|
|
|
if (info_.count(phase) == 0) { |
|
|
|
MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); |
|
|
|
} |
|
|
|
return info_[phase]->arg_list_size; |
|
|
|
} |
|
|
|
|
|
|
|
compile::VmEvalFuncPtr ExecutorPy::GetVmEvalFunc(const std::string &phase) { |
|
|
|
ResourcePtr res = GetResource(phase); |
|
|
|
MS_EXCEPTION_IF_NULL(res); |
|
|
|
@@ -702,8 +695,9 @@ void Pipeline::Run() { |
|
|
|
} |
|
|
|
|
|
|
|
void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list) { |
|
|
|
MS_EXCEPTION_IF_NULL(arg_list); |
|
|
|
std::size_t size = args.size(); |
|
|
|
|
|
|
|
bool arg_list_inited = !arg_list->empty(); |
|
|
|
for (std::size_t i = 0; i < size; i++) { |
|
|
|
py::object arg = args[i]; |
|
|
|
auto ms_context = MsContext::GetInstance(); |
|
|
|
@@ -715,7 +709,14 @@ void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef |
|
|
|
if (!succ) { |
|
|
|
MS_LOG(EXCEPTION) << "The " << i << "th arg convert failed."; |
|
|
|
} |
|
|
|
arg_list->push_back(converted); |
|
|
|
if (!arg_list_inited) { |
|
|
|
arg_list->push_back(converted); |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (i >= arg_list->size()) { |
|
|
|
MS_LOG(EXCEPTION) << "i:" << i << " output of range:" << arg_list->size(); |
|
|
|
} |
|
|
|
(*arg_list)[i] = converted; |
|
|
|
} |
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(res); |
|
|
|
@@ -792,20 +793,23 @@ py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) { |
|
|
|
return args; |
|
|
|
} |
|
|
|
#endif |
|
|
|
std::size_t full_arg_size = ArgListSize(phase_s); |
|
|
|
if (size > full_arg_size) { |
|
|
|
MS_LOG(WARNING) << "The arg num : size = " << size << ". full_arg_size = " << full_arg_size; |
|
|
|
auto iter = info_.find(phase_s); |
|
|
|
if (iter == info_.end()) { |
|
|
|
MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase_s); |
|
|
|
} |
|
|
|
VectorRef arg_list; |
|
|
|
ProcessVmArg(args, phase_s, &arg_list); |
|
|
|
|
|
|
|
auto &execute_info = iter->second; |
|
|
|
MS_EXCEPTION_IF_NULL(execute_info); |
|
|
|
if (size > execute_info->arg_list_size) { |
|
|
|
MS_LOG(WARNING) << "The arg num : size = " << size << ". full_arg_size = " << execute_info->arg_list_size; |
|
|
|
} |
|
|
|
ProcessVmArg(args, phase_s, &execute_info->arg_list); |
|
|
|
// Start to run phase. |
|
|
|
compile::VmEvalFuncPtr run = GetVmEvalFunc(phase_s); |
|
|
|
if (run == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Can't find run graph func for " << phase_s; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Eval run" << backend; |
|
|
|
BaseRef value = (*run)(arg_list); |
|
|
|
BaseRef value = (*run)(execute_info->arg_list); |
|
|
|
MS_LOG(DEBUG) << "Run end"; |
|
|
|
return BaseRefToPyData(value); |
|
|
|
} |
|
|
|
|