From 45175feca9f8676c8c881b15e0b8e98e29aecf31 Mon Sep 17 00:00:00 2001 From: lichun Date: Fri, 12 Mar 2021 09:39:33 +0800 Subject: [PATCH] dynamic shape inference support --- ge/generator/ge_generator.cc | 2 +- inc/framework/generator/ge_generator.h | 2 ++ tests/ut/ge/executor/ge_executor_unittest.cc | 6 ++++++ tests/ut/ge/generator/ge_generator_unittest.cc | 8 ++++++++ 4 files changed, 17 insertions(+), 1 deletion(-) diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index 3f934fdb..e4151af5 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -556,7 +556,7 @@ bool GeGenerator::Impl::SetOmSystemInfo(AttrHolder &obj) { return true; } -static Status SetModelNameForDump(GeRootModelPtr ge_root_model) { +Status GeGenerator::SetModelNameForDump(GeRootModelPtr ge_root_model) { bool is_unknown_shape = false; Status ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape); if (ret != SUCCESS) { diff --git a/inc/framework/generator/ge_generator.h b/inc/framework/generator/ge_generator.h index 2d7d007b..f8cc2264 100644 --- a/inc/framework/generator/ge_generator.h +++ b/inc/framework/generator/ge_generator.h @@ -29,6 +29,7 @@ #include "graph/op_desc.h" #include "graph/detail/attributes_holder.h" #include "omg/omg_inner_types.h" +#include "model/ge_root_model.h" namespace ge { class GE_FUNC_VISIBILITY GeGenerator { @@ -98,6 +99,7 @@ class GE_FUNC_VISIBILITY GeGenerator { const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, bool is_offline = true); Status CheckForSingleOp(OpDescPtr &op_desc, const vector &inputs, const vector &outputs); + Status SetModelNameForDump(GeRootModelPtr ge_root_model); class Impl; diff --git a/tests/ut/ge/executor/ge_executor_unittest.cc b/tests/ut/ge/executor/ge_executor_unittest.cc index a98f9290..a4606320 100644 --- a/tests/ut/ge/executor/ge_executor_unittest.cc +++ b/tests/ut/ge/executor/ge_executor_unittest.cc @@ -39,4 +39,10 @@ TEST_F(UtestGeExecutor, test_single_op_exec) { EXPECT_EQ(exeutor.LoadSingleOp(model_name, model_data, nullptr, nullptr), ACL_ERROR_GE_INTERNAL_ERROR); EXPECT_EQ(exeutor.LoadDynamicSingleOp(model_name, model_data, nullptr, nullptr), PARAM_INVALID); } + +TEST_F(UtestGeExecutor, test_ge_initialize) { + GeExecutor executor; + EXPECT_EQ(executor.Initialize(), SUCCESS); + EXPECT_EQ(executor.Initialize(), SUCCESS); +} } // namespace ge \ No newline at end of file diff --git a/tests/ut/ge/generator/ge_generator_unittest.cc b/tests/ut/ge/generator/ge_generator_unittest.cc index 3daa5592..215aa742 100644 --- a/tests/ut/ge/generator/ge_generator_unittest.cc +++ b/tests/ut/ge/generator/ge_generator_unittest.cc @@ -20,6 +20,7 @@ #define protected public #include "generator/ge_generator.h" #include "graph/utils/tensor_utils.h" +#include "all_ops.h" using namespace std; @@ -71,4 +72,11 @@ TEST_F(UtestGeGenerator, test_build_single_op_online) { ModelBufferData model_buffer; EXPECT_EQ(generator.BuildSingleOpModel(op_desc, inputs, outputs, ENGINE_AIVECTOR, model_buffer), FAILED); } + +TEST_F(UtestGeGenerator, test_set_model_name) { + GeGenerator generator; + generator.Initialize({}); + GeRootModelPtr ge_root_model = make_shared(new (std::nothrow) GeRootModel()); + EXPECT_EQ(generator.SetModelNameForDump(ge_root_model)); +} } // namespace ge