diff --git a/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc b/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc index 1e0ed1e5b2..ce06a45197 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline_ge.cc @@ -177,19 +177,19 @@ void ConvertObjectToTensors(const py::dict &dict, TensorOrderMap *const tensors) bool AddDFGraph(const std::map &info, const py::dict &init_params, const std::string &phase, const py::object &broadcast_params) { FuncGraphPtr anf_graph = info.at(phase)->func_graph; - DfGraphConvertor convertor(anf_graph); + DfGraphConvertor converter(anf_graph); size_t pos = phase.find('.'); std::string net_id = ((pos == std::string::npos || pos == phase.size() - 1) ? phase : phase.substr(pos + 1)); std::string phase_prefix = phase.substr(0, pos); if (phase_prefix == "export") { MS_LOG(INFO) << "Set DfGraphConvertor training : false"; - convertor.set_training(false); + converter.set_training(false); } TensorOrderMap init_tensors{}; ConvertObjectToTensors(init_params, &init_tensors); - (void)convertor.ConvertAllNode().InitParam(init_tensors).BuildGraph(); + (void)converter.ConvertAllNode().InitParam(init_tensors).BuildGraph(); if (!broadcast_params.is_none()) { if (!py::isinstance(broadcast_params)) { @@ -198,38 +198,38 @@ bool AddDFGraph(const std::map &info, const py::di } py::dict broadcast = broadcast_params.cast(); if (broadcast.empty()) { - (void)convertor.GenerateBroadcastGraph(init_tensors); + (void)converter.GenerateBroadcastGraph(init_tensors); } else { TensorOrderMap broadcast_tensors{}; ConvertObjectToTensors(broadcast, &broadcast_tensors); - (void)convertor.GenerateBroadcastGraph(broadcast_tensors); + (void)converter.GenerateBroadcastGraph(broadcast_tensors); } MS_LOG(INFO) << "Generate broadcast graph with params and broadcast_empty is " << broadcast.empty(); } - (void)convertor.GenerateCheckpointGraph(); - if (convertor.ErrCode() != 0) { + (void)converter.GenerateCheckpointGraph(); + if (converter.ErrCode() != 0) { DfGraphManager::GetInstance().ClearGraph(); - MS_LOG(ERROR) << "Convert df graph failed, err:" << convertor.ErrCode(); + MS_LOG(ERROR) << "Convert df graph failed, err:" << converter.ErrCode(); return false; } if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { - convertor.DrawComputeGraph(GetSaveGraphsPathName("ge_graph.dot")); // for debug - convertor.DrawInitGraph(GetSaveGraphsPathName("init_graph.dot")); // for debug - convertor.DrawSaveCheckpointGraph(GetSaveGraphsPathName("save_checkpoint_graph.dot")); // for debug + converter.DrawComputeGraph(GetSaveGraphsPathName("ge_graph.dot")); // for debug + converter.DrawInitGraph(GetSaveGraphsPathName("init_graph.dot")); // for debug + converter.DrawSaveCheckpointGraph(GetSaveGraphsPathName("save_checkpoint_graph.dot")); // for debug } std::string init_graph = "init_subgraph." + net_id; std::string checkpoint_name = "save." + net_id; if (phase.find("train") != std::string::npos) { - (void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph(), {{"ge.exec.variable_acc", "1"}}); + (void)DfGraphManager::GetInstance().AddGraph(phase, converter.GetComputeGraph(), {{"ge.exec.variable_acc", "1"}}); } else { - (void)DfGraphManager::GetInstance().AddGraph(phase, convertor.GetComputeGraph()); + (void)DfGraphManager::GetInstance().AddGraph(phase, converter.GetComputeGraph()); } - (void)DfGraphManager::GetInstance().AddGraph(init_graph, convertor.GetInitGraph()); - (void)DfGraphManager::GetInstance().AddGraph(BROADCAST_GRAPH_NAME, convertor.GetBroadcastGraph()); + (void)DfGraphManager::GetInstance().AddGraph(init_graph, converter.GetInitGraph()); + (void)DfGraphManager::GetInstance().AddGraph(BROADCAST_GRAPH_NAME, converter.GetBroadcastGraph()); - Status ret = DfGraphManager::GetInstance().AddGraph(checkpoint_name, convertor.GetSaveCheckpointGraph()); + Status ret = DfGraphManager::GetInstance().AddGraph(checkpoint_name, converter.GetSaveCheckpointGraph()); if (ret == Status::SUCCESS) { DfGraphManager::GetInstance().SetAnfGraph(checkpoint_name, anf_graph); } @@ -529,8 +529,10 @@ void ExportDFGraph(const std::string &file_name, const std::string &phase) { return; } - (void)ge_graph->SaveToFile(file_name); - MS_LOG(DEBUG) << "Export graph end."; + if (ge_graph->SaveToFile(file_name) != 0) { + MS_LOG(EXCEPTION) << "Export air model failed."; + } + MS_LOG(INFO) << "Export air model finish."; } } // namespace pipeline } // namespace mindspore diff --git a/mindspore/common/api.py b/mindspore/common/api.py index 4babfd256c..a0eb609112 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -512,7 +512,7 @@ class _Executor: if "export" not in phase: init_phase = "init_subgraph" + "." + str(obj.create_time) _exec_init_graph(obj, init_phase) - elif not enable_ge and "export" in phase: + elif "export" in phase: self._build_data_graph(obj, phase) elif BROADCAST_PHASE not in phase and _get_parameter_broadcast(): _parameter_broadcast(obj, auto_parallel_mode)