|
|
|
@@ -556,6 +556,38 @@ bool GeGenerator::Impl::SetOmSystemInfo(AttrHolder &obj) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
static Status SetModelNameForDump(GeRootModelPtr ge_root_model) { |
|
|
|
bool is_unknown_shape = false; |
|
|
|
Status ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(FAILED, "Check root model is unknown shape failed."); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
GeModelPtr model_root = nullptr; |
|
|
|
if (is_unknown_shape) { |
|
|
|
model_root = make_shared<GeModel>(); |
|
|
|
GE_CHECK_NOTNULL(model_root); |
|
|
|
model_root->SetGraph(GraphUtils::CreateGraphFromComputeGraph(ge_root_model->GetRootGraph())); |
|
|
|
ge_root_model->SetSubgraphInstanceNameToModel(ge_root_model->GetRootGraph()->GetName(), model_root); |
|
|
|
} |
|
|
|
|
|
|
|
ModelHelper model_helper; |
|
|
|
string model_name; |
|
|
|
GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); |
|
|
|
Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(), |
|
|
|
model_name); |
|
|
|
if (name_ret != SUCCESS) { |
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"}); |
|
|
|
GELOGE(FAILED, "Get model_name failed. Param --output is invalid"); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel(); |
|
|
|
GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()]; |
|
|
|
GE_RETURN_WITH_LOG_IF_FALSE(ge_model != nullptr, "ge_model cannot be null."); |
|
|
|
ge_model->SetName(model_name); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs, |
|
|
|
ModelBufferData &model, bool is_offline) { |
|
|
|
rtContext_t ctx = nullptr; |
|
|
|
@@ -590,20 +622,10 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr |
|
|
|
} |
|
|
|
|
|
|
|
GE_CHECK_NOTNULL(ge_root_model); |
|
|
|
GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); |
|
|
|
ModelHelper model_helper; |
|
|
|
string model_name = ""; |
|
|
|
Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(), |
|
|
|
model_name); |
|
|
|
if (name_ret != SUCCESS) { |
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"}); |
|
|
|
GELOGE(FAILED, "Get model_name failed. Param --output is invalid."); |
|
|
|
return PARAM_INVALID; |
|
|
|
ret = SetModelNameForDump(ge_root_model); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
return ret; |
|
|
|
} |
|
|
|
map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel(); |
|
|
|
GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()]; |
|
|
|
GE_RETURN_WITH_LOG_IF_FALSE(ge_model != nullptr, "ge_model cannot be null"); |
|
|
|
ge_model->SetName(model_name); |
|
|
|
ret = impl_->SaveRootModel(file_name_prefix, ge_root_model, model); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(ret, "Save model failed"); |
|
|
|
@@ -873,13 +895,12 @@ Status GeGenerator::Impl::SaveRootModel(const string &file_name_prefix, GeRootMo |
|
|
|
"ge root model has no sub model") |
|
|
|
GeModelPtr model_root = nullptr; |
|
|
|
if (is_unknown_shape) { |
|
|
|
model_root = make_shared<GeModel>(); |
|
|
|
model_root->SetGraph(GraphUtils::CreateGraphFromComputeGraph(ge_root_model->GetRootGraph())); |
|
|
|
ge_root_model->SetSubgraphInstanceNameToModel(ge_root_model->GetRootGraph()->GetName(), model_root); |
|
|
|
model_root->SetName(ge_root_model->GetRootGraph()->GetName()); |
|
|
|
auto name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel(); |
|
|
|
model_root = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()]; |
|
|
|
} else { |
|
|
|
model_root = ge_root_model->GetSubgraphInstanceNameToModel().begin()->second; |
|
|
|
} |
|
|
|
GE_CHECK_NOTNULL(model_root); |
|
|
|
// set atc version |
|
|
|
if (!SetAtcVersionInfo(*(model_root.get()))) { |
|
|
|
GELOGW("SetPackageVersionInfo of atc failed!"); |
|
|
|
|