diff --git a/ge/common/helper/model_helper.cc b/ge/common/helper/model_helper.cc index bc4205f8..0fc51518 100644 --- a/ge/common/helper/model_helper.cc +++ b/ge/common/helper/model_helper.cc @@ -875,13 +875,12 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::GetModelNam return SUCCESS; } -/*FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelTool::GetModelInfoFromOm(const char *model_file, +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelTool::GetModelInfoFromOm(const char *model_file, ge::proto::ModelDef &model_def, uint32_t &modeldef_size) { GE_CHECK_NOTNULL(model_file); ge::ModelData model; int32_t priority = 0; - Status ret = ModelParserBase::LoadFromFile(model_file, "", priority, model); if (ret != SUCCESS) { GELOGE(ret, "LoadFromFile failed."); @@ -893,6 +892,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::GetModelNam model.model_data = nullptr; } }; + GE_MAKE_GUARD(release, callback); uint8_t *model_data = nullptr; uint32_t model_len = 0; @@ -905,17 +905,17 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::GetModelNam return ret; } - OmFileLoadHelper omFileLoadHelper; - ret = omFileLoadHelper.Init(model_data, model_len); - if (ret != ge::GRAPH_SUCCESS) { + OmFileLoadHelper om_load_helper; + ret = om_load_helper.Init(model_data, model_len); + if (ret != SUCCESS) { ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {"Om file init failed"}); GELOGE(ge::FAILED, "Om file init failed."); return ret; } ModelPartition ir_part; - ret = omFileLoadHelper.GetModelPartition(MODEL_DEF, ir_part); - if (ret != ge::GRAPH_SUCCESS) { + ret = om_load_helper.GetModelPartition(MODEL_DEF, ir_part); + if (ret != SUCCESS) { ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {"Get model part failed"}); GELOGE(ge::FAILED, "Get model part failed."); return ret; @@ -968,5 +968,5 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelTool::GetModelInfoF GELOGE(FAILED, "ParseFromString failed. exception message : %s", e.what()); return FAILED; } -}*/ +} } // namespace ge diff --git a/ge/init/gelib.cc b/ge/init/gelib.cc index 1463f25e..19085c19 100644 --- a/ge/init/gelib.cc +++ b/ge/init/gelib.cc @@ -533,7 +533,7 @@ void GELib::RollbackInit() { VarManagerPool::Instance().Destory(); } -/*Status GEInit::Initialize(const map &options) { +Status GEInit::Initialize(const map &options) { Status ret = SUCCESS; std::shared_ptr instance_ptr = ge::GELib::GetInstance(); if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { @@ -543,10 +543,13 @@ void GELib::RollbackInit() { } Status GEInit::Finalize() { - return GELib::GetInstance()->Finalize(); + std::shared_ptr instance_ptr = ge::GELib::GetInstance(); + if (instance_ptr != nullptr) { + return instance_ptr->Finalize(); + } } string GEInit::GetPath() { return GELib::GetPath(); -}*/ +} } // namespace ge diff --git a/tests/ut/ge/graph/load/model_helper_unittest.cc b/tests/ut/ge/graph/load/model_helper_unittest.cc index 455285bf..1d2127cb 100644 --- a/tests/ut/ge/graph/load/model_helper_unittest.cc +++ b/tests/ut/ge/graph/load/model_helper_unittest.cc @@ -8,7 +8,7 @@ * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, + * distributed under the License is distributed on an "AS I#include "common/model_parser/base.h"S" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. @@ -18,6 +18,8 @@ #define private public #define protected public #include "framework/common/helper/model_helper.h" +#include "framework/omg/model_tool.h" +#include "framework/omg/ge_init.h" #include "ge/model/ge_model.h" #undef private #undef protected @@ -49,4 +51,21 @@ TEST_F(UtestModelHelper, save_size_to_modeldef) ModelHelper model_helper; EXPECT_EQ(SUCCESS, model_helper.SaveSizeToModelDef(ge_model)); } +TEST_F(UtestModelHelper, atc_test) +{ + ge::proto::ModelDef model_def; + uint32_t modeldef_size = 0; + + GEInit::Finalize(); + char buffer[1024]; + getcwd(buffer, 1024); + printf("%s", buffer); + string path=buffer; + string file_path=path + "Makefile"; + + ModelTool::GetModelInfoFromOm(file_path.c_str(), model_def, modeldef_size); + ModelTool::GetModelInfoFromOm("123.om", model_def, modeldef_size); + ModelTool::GetModelInfoFromPbtxt(file_path.c_str(), model_def); + ModelTool::GetModelInfoFromPbtxt("123.pbtxt", model_def); +} } // namespace ge