| @@ -546,6 +546,38 @@ bool GeGenerator::Impl::SetOppVersionInfo(AttrHolder &obj) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| static Status SetModelNameForDump(GeRootModelPtr ge_root_model) { | |||||
| bool is_unknown_shape = false; | |||||
| Status ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(FAILED, "Check root model is unknown shape failed."); | |||||
| return FAILED; | |||||
| } | |||||
| GeModelPtr model_root = nullptr; | |||||
| if (is_unknown_shape) { | |||||
| model_root = make_shared<GeModel>(); | |||||
| GE_CHECK_NOTNULL(model_root); | |||||
| model_root->SetGraph(GraphUtils::CreateGraphFromComputeGraph(ge_root_model->GetRootGraph())); | |||||
| ge_root_model->SetSubgraphInstanceNameToModel(ge_root_model->GetRootGraph()->GetName(), model_root); | |||||
| } | |||||
| ModelHelper model_helper; | |||||
| string model_name = ""; | |||||
| GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); | |||||
| Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(), | |||||
| model_name); | |||||
| if (name_ret != SUCCESS) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"}); | |||||
| GELOGE(FAILED, "Get model_name failed. Param --output is invalid."); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel(); | |||||
| GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()]; | |||||
| GE_RETURN_WITH_LOG_IF_FALSE(ge_model != nullptr, "ge model cannot be null."); | |||||
| ge_model->SetName(model_name); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs, | Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs, | ||||
| ModelBufferData &model, bool is_offline) { | ModelBufferData &model, bool is_offline) { | ||||
| rtContext_t ctx = nullptr; | rtContext_t ctx = nullptr; | ||||
| @@ -554,7 +586,6 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr | |||||
| GELOGD("Current ctx is null."); | GELOGD("Current ctx is null."); | ||||
| ctx = nullptr; | ctx = nullptr; | ||||
| } | } | ||||
| GeRootModelPtr ge_root_model = nullptr; | GeRootModelPtr ge_root_model = nullptr; | ||||
| GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); | GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); | ||||
| impl_->is_offline_ = is_offline; | impl_->is_offline_ = is_offline; | ||||
| @@ -580,20 +611,10 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr | |||||
| } | } | ||||
| GE_CHECK_NOTNULL(ge_root_model); | GE_CHECK_NOTNULL(ge_root_model); | ||||
| GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); | |||||
| ModelHelper model_helper; | |||||
| string model_name = ""; | |||||
| Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(), | |||||
| model_name); | |||||
| if (name_ret != SUCCESS) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"}); | |||||
| GELOGE(FAILED, "Get model_name failed. Param --output is invalid."); | |||||
| return PARAM_INVALID; | |||||
| ret = SetModelNameForDump(ge_root_model); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | } | ||||
| map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel(); | |||||
| GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()]; | |||||
| GE_RETURN_WITH_LOG_IF_FALSE(ge_model != nullptr, "ge_model cannot be null"); | |||||
| ge_model->SetName(model_name); | |||||
| ret = impl_->SaveRootModel(file_name_prefix, ge_root_model, model); | ret = impl_->SaveRootModel(file_name_prefix, ge_root_model, model); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "Save model failed"); | GELOGE(ret, "Save model failed"); | ||||
| @@ -602,11 +623,9 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| if (ctx != nullptr) { | if (ctx != nullptr) { | ||||
| (void)rtCtxSetCurrent(ctx); | (void)rtCtxSetCurrent(ctx); | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -827,13 +846,12 @@ Status GeGenerator::Impl::SaveRootModel(const string &file_name_prefix, GeRootMo | |||||
| "ge root model has no sub model") | "ge root model has no sub model") | ||||
| GeModelPtr model_root = nullptr; | GeModelPtr model_root = nullptr; | ||||
| if (is_unknown_shape) { | if (is_unknown_shape) { | ||||
| model_root = make_shared<GeModel>(); | |||||
| model_root->SetGraph(GraphUtils::CreateGraphFromComputeGraph(ge_root_model->GetRootGraph())); | |||||
| ge_root_model->SetSubgraphInstanceNameToModel(ge_root_model->GetRootGraph()->GetName(), model_root); | |||||
| model_root->SetName(ge_root_model->GetRootGraph()->GetName()); | |||||
| auto name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel(); | |||||
| model_root = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()]; | |||||
| } else { | } else { | ||||
| model_root = ge_root_model->GetSubgraphInstanceNameToModel().begin()->second; | model_root = ge_root_model->GetSubgraphInstanceNameToModel().begin()->second; | ||||
| } | } | ||||
| GE_CHECK_NOTNULL(model_root); | |||||
| // set atc version | // set atc version | ||||
| if (!SetAtcVersionInfo(*(model_root.get()))) { | if (!SetAtcVersionInfo(*(model_root.get()))) { | ||||
| GELOGW("SetPackageVersionInfo of atc failed!"); | GELOGW("SetPackageVersionInfo of atc failed!"); | ||||