Browse Source

fix

pull/1177/head
wjm 4 years ago
parent
commit
4fdd9a72fe
3 changed files with 34 additions and 12 deletions
  1. +8
    -8
      ge/common/helper/model_helper.cc
  2. +6
    -3
      ge/init/gelib.cc
  3. +20
    -1
      tests/ut/ge/graph/load/model_helper_unittest.cc

+ 8
- 8
ge/common/helper/model_helper.cc View File

@@ -875,13 +875,12 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::GetModelNam
return SUCCESS;
}

/*FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelTool::GetModelInfoFromOm(const char *model_file,
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelTool::GetModelInfoFromOm(const char *model_file,
ge::proto::ModelDef &model_def,
uint32_t &modeldef_size) {
GE_CHECK_NOTNULL(model_file);
ge::ModelData model;
int32_t priority = 0;

Status ret = ModelParserBase::LoadFromFile(model_file, "", priority, model);
if (ret != SUCCESS) {
GELOGE(ret, "LoadFromFile failed.");
@@ -893,6 +892,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::GetModelNam
model.model_data = nullptr;
}
};
GE_MAKE_GUARD(release, callback);

uint8_t *model_data = nullptr;
uint32_t model_len = 0;
@@ -905,17 +905,17 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::GetModelNam
return ret;
}

OmFileLoadHelper omFileLoadHelper;
ret = omFileLoadHelper.Init(model_data, model_len);
if (ret != ge::GRAPH_SUCCESS) {
OmFileLoadHelper om_load_helper;
ret = om_load_helper.Init(model_data, model_len);
if (ret != SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {"Om file init failed"});
GELOGE(ge::FAILED, "Om file init failed.");
return ret;
}

ModelPartition ir_part;
ret = omFileLoadHelper.GetModelPartition(MODEL_DEF, ir_part);
if (ret != ge::GRAPH_SUCCESS) {
ret = om_load_helper.GetModelPartition(MODEL_DEF, ir_part);
if (ret != SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E19021", {"reason"}, {"Get model part failed"});
GELOGE(ge::FAILED, "Get model part failed.");
return ret;
@@ -968,5 +968,5 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelTool::GetModelInfoF
GELOGE(FAILED, "ParseFromString failed. exception message : %s", e.what());
return FAILED;
}
}*/
}
} // namespace ge

+ 6
- 3
ge/init/gelib.cc View File

@@ -533,7 +533,7 @@ void GELib::RollbackInit() {
VarManagerPool::Instance().Destory();
}

/*Status GEInit::Initialize(const map<string, string> &options) {
Status GEInit::Initialize(const map<string, string> &options) {
Status ret = SUCCESS;
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
@@ -543,10 +543,13 @@ void GELib::RollbackInit() {
}

Status GEInit::Finalize() {
return GELib::GetInstance()->Finalize();
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr != nullptr) {
return instance_ptr->Finalize();
}
}

string GEInit::GetPath() {
return GELib::GetPath();
}*/
}
} // namespace ge

+ 20
- 1
tests/ut/ge/graph/load/model_helper_unittest.cc View File

@@ -8,7 +8,7 @@
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* distributed under the License is distributed on an "AS I#include "common/model_parser/base.h"S" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
@@ -18,6 +18,8 @@
#define private public
#define protected public
#include "framework/common/helper/model_helper.h"
#include "framework/omg/model_tool.h"
#include "framework/omg/ge_init.h"
#include "ge/model/ge_model.h"
#undef private
#undef protected
@@ -49,4 +51,21 @@ TEST_F(UtestModelHelper, save_size_to_modeldef)
ModelHelper model_helper;
EXPECT_EQ(SUCCESS, model_helper.SaveSizeToModelDef(ge_model));
}
TEST_F(UtestModelHelper, atc_test)
{
ge::proto::ModelDef model_def;
uint32_t modeldef_size = 0;
GEInit::Finalize();
char buffer[1024];
getcwd(buffer, 1024);
printf("%s", buffer);
string path=buffer;
string file_path=path + "Makefile";
ModelTool::GetModelInfoFromOm(file_path.c_str(), model_def, modeldef_size);
ModelTool::GetModelInfoFromOm("123.om", model_def, modeldef_size);
ModelTool::GetModelInfoFromPbtxt(file_path.c_str(), model_def);
ModelTool::GetModelInfoFromPbtxt("123.pbtxt", model_def);
}
} // namespace ge

Loading…
Cancel
Save