|
|
|
@@ -227,10 +227,12 @@ bool AddDFGraph(const std::map<std::string, ExecutorInfoPtr>& info, const py::di |
|
|
|
(void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph()); |
|
|
|
} |
|
|
|
(void)DfGraphManager::GetInstance().AddGraph(init_graph, convertor.GetInitGraph()); |
|
|
|
(void)DfGraphManager::GetInstance().AddGraph(checkpoint_name, convertor.GetSaveCheckpointGraph()); |
|
|
|
(void)DfGraphManager::GetInstance().AddGraph(BROADCAST_GRAPH_NAME, convertor.GetBroadcastGraph()); |
|
|
|
|
|
|
|
DfGraphManager::GetInstance().SetAnfGraph(checkpoint_name, anf_graph); |
|
|
|
Status ret = DfGraphManager::GetInstance().AddGraph(checkpoint_name, convertor.GetSaveCheckpointGraph()); |
|
|
|
if (ret == Status::SUCCESS) { |
|
|
|
DfGraphManager::GetInstance().SetAnfGraph(checkpoint_name, anf_graph); |
|
|
|
} |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
@@ -389,8 +391,7 @@ std::shared_ptr<py::object> DoExecGraph(const FuncGraphPtr& graph, const std::ve |
|
|
|
const std::string& phase) { |
|
|
|
std::vector<GeTensorPtr> ge_tensors = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW); |
|
|
|
if (ge_tensors.size() != inputs.size()) { |
|
|
|
MS_LOG(ERROR) << "Args convert to ge tensor error"; |
|
|
|
return nullptr; |
|
|
|
MS_LOG(EXCEPTION) << "Convert me args to ge tensor error."; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<GeTensorPtr> ge_outputs; |
|
|
|
@@ -401,8 +402,7 @@ std::shared_ptr<py::object> DoExecGraph(const FuncGraphPtr& graph, const std::ve |
|
|
|
auto graph_runner = DfGraphManager::GetInstance().GetGraphRunner(); |
|
|
|
|
|
|
|
if (graph_runner == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Can not found GraphRunner"; |
|
|
|
return nullptr; |
|
|
|
MS_LOG(EXCEPTION) << "Can not found GraphRunner."; |
|
|
|
} |
|
|
|
|
|
|
|
{ |
|
|
|
@@ -419,7 +419,7 @@ std::shared_ptr<py::object> DoExecGraph(const FuncGraphPtr& graph, const std::ve |
|
|
|
|
|
|
|
std::vector<MeTensorPtr> me_outputs = TransformUtil::ConvertGeTensors(ge_outputs); |
|
|
|
if (me_outputs.size() != ge_outputs.size()) { |
|
|
|
MS_LOG(ERROR) << "Convert output Ge tensor to Me tensor failed"; |
|
|
|
MS_LOG(WARNING) << "Convert output Ge tensor to Me tensor failed"; |
|
|
|
} |
|
|
|
|
|
|
|
py::tuple outputs(me_outputs.size()); |
|
|
|
|