Browse Source

Bugfix: fix the error of missing ge model in dynamic shape scene

pull/817/head
lichun 5 years ago
parent
commit
281f3ffbdc
1 changed files with 38 additions and 20 deletions
  1. +38
    -20
      ge/generator/ge_generator.cc

+ 38
- 20
ge/generator/ge_generator.cc View File

@@ -546,6 +546,38 @@ bool GeGenerator::Impl::SetOppVersionInfo(AttrHolder &obj) {
return true;
}

static Status SetModelNameForDump(GeRootModelPtr ge_root_model) {
bool is_unknown_shape = false;
Status ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape);
if (ret != SUCCESS) {
GELOGE(FAILED, "Check root model is unknown shape failed.");
return FAILED;
}
GeModelPtr model_root = nullptr;
if (is_unknown_shape) {
model_root = make_shared<GeModel>();
GE_CHECK_NOTNULL(model_root);
model_root->SetGraph(GraphUtils::CreateGraphFromComputeGraph(ge_root_model->GetRootGraph()));
ge_root_model->SetSubgraphInstanceNameToModel(ge_root_model->GetRootGraph()->GetName(), model_root);
}

ModelHelper model_helper;
string model_name = "";
GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(),
model_name);
if (name_ret != SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"});
GELOGE(FAILED, "Get model_name failed. Param --output is invalid.");
return PARAM_INVALID;
}
map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
GE_RETURN_WITH_LOG_IF_FALSE(ge_model != nullptr, "ge model cannot be null.");
ge_model->SetName(model_name);
return SUCCESS;
}

Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs,
ModelBufferData &model, bool is_offline) {
rtContext_t ctx = nullptr;
@@ -554,7 +586,6 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr
GELOGD("Current ctx is null.");
ctx = nullptr;
}

GeRootModelPtr ge_root_model = nullptr;
GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID);
impl_->is_offline_ = is_offline;
@@ -580,20 +611,10 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr
}

GE_CHECK_NOTNULL(ge_root_model);
GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
ModelHelper model_helper;
string model_name = "";
Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(),
model_name);
if (name_ret != SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"});
GELOGE(FAILED, "Get model_name failed. Param --output is invalid.");
return PARAM_INVALID;
ret = SetModelNameForDump(ge_root_model);
if (ret != SUCCESS) {
return ret;
}
map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
GE_RETURN_WITH_LOG_IF_FALSE(ge_model != nullptr, "ge_model cannot be null");
ge_model->SetName(model_name);
ret = impl_->SaveRootModel(file_name_prefix, ge_root_model, model);
if (ret != SUCCESS) {
GELOGE(ret, "Save model failed");
@@ -602,11 +623,9 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr
}
return ret;
}

if (ctx != nullptr) {
(void)rtCtxSetCurrent(ctx);
}

return SUCCESS;
}

@@ -827,13 +846,12 @@ Status GeGenerator::Impl::SaveRootModel(const string &file_name_prefix, GeRootMo
"ge root model has no sub model")
GeModelPtr model_root = nullptr;
if (is_unknown_shape) {
model_root = make_shared<GeModel>();
model_root->SetGraph(GraphUtils::CreateGraphFromComputeGraph(ge_root_model->GetRootGraph()));
ge_root_model->SetSubgraphInstanceNameToModel(ge_root_model->GetRootGraph()->GetName(), model_root);
model_root->SetName(ge_root_model->GetRootGraph()->GetName());
auto name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
model_root = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
} else {
model_root = ge_root_model->GetSubgraphInstanceNameToModel().begin()->second;
}
GE_CHECK_NOTNULL(model_root);
// set atc version
if (!SetAtcVersionInfo(*(model_root.get()))) {
GELOGW("SetPackageVersionInfo of atc failed!");


Loading…
Cancel
Save