From 281f3ffbdcf5d6063bc067101da6e5a5243ee133 Mon Sep 17 00:00:00 2001 From: lichun Date: Wed, 30 Dec 2020 14:12:35 +0800 Subject: [PATCH] Bugfix: fix the error of missing ge model in dynamic shape scene --- ge/generator/ge_generator.cc | 58 +++++++++++++++++++++++------------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index 016f9ef2..30c75b33 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -546,6 +546,38 @@ bool GeGenerator::Impl::SetOppVersionInfo(AttrHolder &obj) { return true; } +static Status SetModelNameForDump(GeRootModelPtr ge_root_model) { + bool is_unknown_shape = false; + Status ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape); + if (ret != SUCCESS) { + GELOGE(FAILED, "Check root model is unknown shape failed."); + return FAILED; + } + GeModelPtr model_root = nullptr; + if (is_unknown_shape) { + model_root = make_shared(); + GE_CHECK_NOTNULL(model_root); + model_root->SetGraph(GraphUtils::CreateGraphFromComputeGraph(ge_root_model->GetRootGraph())); + ge_root_model->SetSubgraphInstanceNameToModel(ge_root_model->GetRootGraph()->GetName(), model_root); + } + + ModelHelper model_helper; + string model_name = ""; + GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); + Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(), + model_name); + if (name_ret != SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"}); + GELOGE(FAILED, "Get model_name failed. Param --output is invalid."); + return PARAM_INVALID; + } + map name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel(); + GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()]; + GE_RETURN_WITH_LOG_IF_FALSE(ge_model != nullptr, "ge model cannot be null."); + ge_model->SetName(model_name); + return SUCCESS; +} + Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector &inputs, ModelBufferData &model, bool is_offline) { rtContext_t ctx = nullptr; @@ -554,7 +586,6 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr GELOGD("Current ctx is null."); ctx = nullptr; } - GeRootModelPtr ge_root_model = nullptr; GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); impl_->is_offline_ = is_offline; @@ -580,20 +611,10 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr } GE_CHECK_NOTNULL(ge_root_model); - GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); - ModelHelper model_helper; - string model_name = ""; - Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(), - model_name); - if (name_ret != SUCCESS) { - ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"}); - GELOGE(FAILED, "Get model_name failed. Param --output is invalid."); - return PARAM_INVALID; + ret = SetModelNameForDump(ge_root_model); + if (ret != SUCCESS) { + return ret; } - map name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel(); - GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()]; - GE_RETURN_WITH_LOG_IF_FALSE(ge_model != nullptr, "ge_model cannot be null"); - ge_model->SetName(model_name); ret = impl_->SaveRootModel(file_name_prefix, ge_root_model, model); if (ret != SUCCESS) { GELOGE(ret, "Save model failed"); @@ -602,11 +623,9 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr } return ret; } - if (ctx != nullptr) { (void)rtCtxSetCurrent(ctx); } - return SUCCESS; } @@ -827,13 +846,12 @@ Status GeGenerator::Impl::SaveRootModel(const string &file_name_prefix, GeRootMo "ge root model has no sub model") GeModelPtr model_root = nullptr; if (is_unknown_shape) { - model_root = make_shared(); - model_root->SetGraph(GraphUtils::CreateGraphFromComputeGraph(ge_root_model->GetRootGraph())); - ge_root_model->SetSubgraphInstanceNameToModel(ge_root_model->GetRootGraph()->GetName(), model_root); - model_root->SetName(ge_root_model->GetRootGraph()->GetName()); + auto name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel(); + model_root = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()]; } else { model_root = ge_root_model->GetSubgraphInstanceNameToModel().begin()->second; } + GE_CHECK_NOTNULL(model_root); // set atc version if (!SetAtcVersionInfo(*(model_root.get()))) { GELOGW("SetPackageVersionInfo of atc failed!");