|
|
|
@@ -177,19 +177,19 @@ void ConvertObjectToTensors(const py::dict &dict, TensorOrderMap *const tensors) |
|
|
|
bool AddDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::dict &init_params, |
|
|
|
const std::string &phase, const py::object &broadcast_params) { |
|
|
|
FuncGraphPtr anf_graph = info.at(phase)->func_graph; |
|
|
|
DfGraphConvertor convertor(anf_graph); |
|
|
|
DfGraphConvertor converter(anf_graph); |
|
|
|
|
|
|
|
size_t pos = phase.find('.'); |
|
|
|
std::string net_id = ((pos == std::string::npos || pos == phase.size() - 1) ? phase : phase.substr(pos + 1)); |
|
|
|
std::string phase_prefix = phase.substr(0, pos); |
|
|
|
if (phase_prefix == "export") { |
|
|
|
MS_LOG(INFO) << "Set DfGraphConvertor training : false"; |
|
|
|
convertor.set_training(false); |
|
|
|
converter.set_training(false); |
|
|
|
} |
|
|
|
|
|
|
|
TensorOrderMap init_tensors{}; |
|
|
|
ConvertObjectToTensors(init_params, &init_tensors); |
|
|
|
(void)convertor.ConvertAllNode().InitParam(init_tensors).BuildGraph(); |
|
|
|
(void)converter.ConvertAllNode().InitParam(init_tensors).BuildGraph(); |
|
|
|
|
|
|
|
if (!broadcast_params.is_none()) { |
|
|
|
if (!py::isinstance<py::dict>(broadcast_params)) { |
|
|
|
@@ -198,38 +198,38 @@ bool AddDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::di |
|
|
|
} |
|
|
|
py::dict broadcast = broadcast_params.cast<py::dict>(); |
|
|
|
if (broadcast.empty()) { |
|
|
|
(void)convertor.GenerateBroadcastGraph(init_tensors); |
|
|
|
(void)converter.GenerateBroadcastGraph(init_tensors); |
|
|
|
} else { |
|
|
|
TensorOrderMap broadcast_tensors{}; |
|
|
|
ConvertObjectToTensors(broadcast, &broadcast_tensors); |
|
|
|
(void)convertor.GenerateBroadcastGraph(broadcast_tensors); |
|
|
|
(void)converter.GenerateBroadcastGraph(broadcast_tensors); |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Generate broadcast graph with params and broadcast_empty is " << broadcast.empty(); |
|
|
|
} |
|
|
|
|
|
|
|
(void)convertor.GenerateCheckpointGraph(); |
|
|
|
if (convertor.ErrCode() != 0) { |
|
|
|
(void)converter.GenerateCheckpointGraph(); |
|
|
|
if (converter.ErrCode() != 0) { |
|
|
|
DfGraphManager::GetInstance().ClearGraph(); |
|
|
|
MS_LOG(ERROR) << "Convert df graph failed, err:" << convertor.ErrCode(); |
|
|
|
MS_LOG(ERROR) << "Convert df graph failed, err:" << converter.ErrCode(); |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { |
|
|
|
convertor.DrawComputeGraph(GetSaveGraphsPathName("ge_graph.dot")); // for debug |
|
|
|
convertor.DrawInitGraph(GetSaveGraphsPathName("init_graph.dot")); // for debug |
|
|
|
convertor.DrawSaveCheckpointGraph(GetSaveGraphsPathName("save_checkpoint_graph.dot")); // for debug |
|
|
|
converter.DrawComputeGraph(GetSaveGraphsPathName("ge_graph.dot")); // for debug |
|
|
|
converter.DrawInitGraph(GetSaveGraphsPathName("init_graph.dot")); // for debug |
|
|
|
converter.DrawSaveCheckpointGraph(GetSaveGraphsPathName("save_checkpoint_graph.dot")); // for debug |
|
|
|
} |
|
|
|
std::string init_graph = "init_subgraph." + net_id; |
|
|
|
std::string checkpoint_name = "save." + net_id; |
|
|
|
if (phase.find("train") != std::string::npos) { |
|
|
|
(void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph(), {{"ge.exec.variable_acc", "1"}}); |
|
|
|
(void)DfGraphManager::GetInstance().AddGraph(phase, converter.GetComputeGraph(), {{"ge.exec.variable_acc", "1"}}); |
|
|
|
} else { |
|
|
|
(void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph()); |
|
|
|
(void)DfGraphManager::GetInstance().AddGraph(phase, converter.GetComputeGraph()); |
|
|
|
} |
|
|
|
(void)DfGraphManager::GetInstance().AddGraph(init_graph, convertor.GetInitGraph()); |
|
|
|
(void)DfGraphManager::GetInstance().AddGraph(BROADCAST_GRAPH_NAME, convertor.GetBroadcastGraph()); |
|
|
|
(void)DfGraphManager::GetInstance().AddGraph(init_graph, converter.GetInitGraph()); |
|
|
|
(void)DfGraphManager::GetInstance().AddGraph(BROADCAST_GRAPH_NAME, converter.GetBroadcastGraph()); |
|
|
|
|
|
|
|
Status ret = DfGraphManager::GetInstance().AddGraph(checkpoint_name, convertor.GetSaveCheckpointGraph()); |
|
|
|
Status ret = DfGraphManager::GetInstance().AddGraph(checkpoint_name, converter.GetSaveCheckpointGraph()); |
|
|
|
if (ret == Status::SUCCESS) { |
|
|
|
DfGraphManager::GetInstance().SetAnfGraph(checkpoint_name, anf_graph); |
|
|
|
} |
|
|
|
@@ -529,8 +529,10 @@ void ExportDFGraph(const std::string &file_name, const std::string &phase) { |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
(void)ge_graph->SaveToFile(file_name); |
|
|
|
MS_LOG(DEBUG) << "Export graph end."; |
|
|
|
if (ge_graph->SaveToFile(file_name) != 0) { |
|
|
|
MS_LOG(EXCEPTION) << "Export air model failed."; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Export air model finish."; |
|
|
|
} |
|
|
|
} // namespace pipeline |
|
|
|
} // namespace mindspore |