Browse Source

dynamic shape inference support

pull/1239/head
lichun 5 years ago
parent
commit
45175feca9
4 changed files with 17 additions and 1 deletions
  1. +1
    -1
      ge/generator/ge_generator.cc
  2. +2
    -0
      inc/framework/generator/ge_generator.h
  3. +6
    -0
      tests/ut/ge/executor/ge_executor_unittest.cc
  4. +8
    -0
      tests/ut/ge/generator/ge_generator_unittest.cc

+ 1
- 1
ge/generator/ge_generator.cc View File

@@ -556,7 +556,7 @@ bool GeGenerator::Impl::SetOmSystemInfo(AttrHolder &obj) {
return true;
}

static Status SetModelNameForDump(GeRootModelPtr ge_root_model) {
Status GeGenerator::SetModelNameForDump(GeRootModelPtr ge_root_model) {
bool is_unknown_shape = false;
Status ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape);
if (ret != SUCCESS) {


+ 2
- 0
inc/framework/generator/ge_generator.h View File

@@ -29,6 +29,7 @@
#include "graph/op_desc.h"
#include "graph/detail/attributes_holder.h"
#include "omg/omg_inner_types.h"
#include "model/ge_root_model.h"

namespace ge {
class GE_FUNC_VISIBILITY GeGenerator {
@@ -98,6 +99,7 @@ class GE_FUNC_VISIBILITY GeGenerator {
const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff,
bool is_offline = true);
Status CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs);
Status SetModelNameForDump(GeRootModelPtr ge_root_model);

class Impl;



+ 6
- 0
tests/ut/ge/executor/ge_executor_unittest.cc View File

@@ -39,4 +39,10 @@ TEST_F(UtestGeExecutor, test_single_op_exec) {
EXPECT_EQ(exeutor.LoadSingleOp(model_name, model_data, nullptr, nullptr), ACL_ERROR_GE_INTERNAL_ERROR);
EXPECT_EQ(exeutor.LoadDynamicSingleOp(model_name, model_data, nullptr, nullptr), PARAM_INVALID);
}

TEST_F(UtestGeExecutor, test_ge_initialize) {
GeExecutor executor;
EXPECT_EQ(executor.Initialize(), SUCCESS);
EXPECT_EQ(executor.Initialize(), SUCCESS);
}
} // namespace ge

+ 8
- 0
tests/ut/ge/generator/ge_generator_unittest.cc View File

@@ -20,6 +20,7 @@
#define protected public
#include "generator/ge_generator.h"
#include "graph/utils/tensor_utils.h"
#include "all_ops.h"

using namespace std;

@@ -71,4 +72,11 @@ TEST_F(UtestGeGenerator, test_build_single_op_online) {
ModelBufferData model_buffer;
EXPECT_EQ(generator.BuildSingleOpModel(op_desc, inputs, outputs, ENGINE_AIVECTOR, model_buffer), FAILED);
}

TEST_F(UtestGeGenerator, test_set_model_name) {
GeGenerator generator;
generator.Initialize({});
GeRootModelPtr ge_root_model = make_shared<GeRootModelPtr>(new (std::nothrow) GeRootModel());
EXPECT_EQ(generator.SetModelNameForDump(ge_root_model));
}
} // namespace ge

Loading…
Cancel
Save