Browse Source

parset ut

pull/288/head
y00500818 4 years ago
parent
commit
3308a864f3
8 changed files with 210 additions and 39 deletions
  1. +1
    -1
      parser/onnx/subgraph_adapter/subgraph_adapter.h
  2. +6
    -0
      tests/depends/error_manager/src/error_manager_stub.cc
  3. +4
    -1
      tests/ut/parser/CMakeLists.txt
  4. BIN
      tests/ut/parser/testcase/onnx_parser_testcase/onnx_model/conv2d.onnx
  5. +97
    -0
      tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc
  6. +0
    -37
      tests/ut/parser/testcase/parser_unittest.cc
  7. +13
    -0
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/add.pb
  8. +89
    -0
      tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc

+ 1
- 1
parser/onnx/subgraph_adapter/subgraph_adapter.h View File

@@ -36,7 +36,7 @@
#include "proto/onnx/ge_onnx.pb.h"
#include "external/register/register_error_codes.h"
#include "framework/omg/parser/parser_types.h"
#include "onnx_util.h"
#include "parser/onnx/onnx_util.h"

using Status = domi::Status;
using namespace ge::parser;


+ 6
- 0
tests/depends/error_manager/src/error_manager_stub.cc View File

@@ -16,6 +16,12 @@

#include "common/util/error_manager/error_manager.h"

namespace ErrorMessage {
int FormatErrorMessage(char *str_dst, size_t dst_max, const char *format, ...) {
return 0;
}
}

ErrorManager &ErrorManager::GetInstance() {
static ErrorManager instance;
return instance;


+ 4
- 1
tests/ut/parser/CMakeLists.txt View File

@@ -254,6 +254,8 @@ set(PARSER_SRC_FILES
"${PARSER_DIR}/parser/onnx/onnx_data_parser.cc"
"${PARSER_DIR}/parser/onnx/onnx_parser.cc"
"${PARSER_DIR}/parser/onnx/onnx_util.cc"
"${PARSER_DIR}/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc"
"${PARSER_DIR}/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc"
"${PARSER_DIR}/parser/tensorflow/graph_functiondef.cc"
"${PARSER_DIR}/parser/tensorflow/graph_optimizer.cc"
"${PARSER_DIR}/parser/tensorflow/iterator_fusion_pass.cc"
@@ -298,7 +300,8 @@ include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework)


set(PARSER_UT_FILES
"testcase/parser_unittest.cc"
"testcase/onnx_parser_testcase/onnx_parser_unittest.cc"
"testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc"
)

############ libut_parser_common.a ############


BIN
tests/ut/parser/testcase/onnx_parser_testcase/onnx_model/conv2d.onnx View File


+ 97
- 0
tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc View File

@@ -0,0 +1,97 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <iostream>
#include "parser/common/op_parser_factory.h"
#include "graph/operator_reg.h"
#include "external/graph/types.h"
#include "register/op_registry.h"
#include "parser/common/register_tbe.h"
#include "external/parser/onnx_parser.h"


namespace ge {
class UtestOnnxParser : public testing::Test {
protected:
void SetUp() {}

void TearDown() {}

public:
void RegisterCustomOp();
};

static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) {
return SUCCESS;
}

void UtestOnnxParser::RegisterCustomOp() {
REGISTER_CUSTOM_OP("Conv2D")
.FrameworkType(domi::ONNX)
.OriginOpType("ai.onnx::11::Conv")
.ParseParamsFn(ParseParams);

std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas;
for (auto reg_data : reg_datas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegistry::Instance()->Register(reg_data);
}
domi::OpRegistry::Instance()->registrationDatas.clear();
}

namespace {
REG_OP(Data)
.INPUT(x, TensorType::ALL())
.OUTPUT(y, TensorType::ALL())
.ATTR(index, Int, 0)
.OP_END_FACTORY_REG(Data)

REG_OP(Const)
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \
DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.ATTR(value, Tensor, Tensor())
.OP_END_FACTORY_REG(Const)

REG_OP(Conv2D)
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8}))
.INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8}))
.OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
.OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
.REQUIRED_ATTR(strides, ListInt)
.REQUIRED_ATTR(pads, ListInt)
.ATTR(dilations, ListInt, {1, 1, 1, 1})
.ATTR(groups, Int, 1)
.ATTR(data_format, String, "NHWC")
.ATTR(offset_x, Int, 0)
.OP_END_FACTORY_REG(Conv2D)
}

TEST_F(UtestOnnxParser, onnx_parser_success) {
RegisterCustomOp();

std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/onnx_model/conv2d.onnx";
std::map<ge::AscendString, ge::AscendString> parser_params;
ge::Graph graph;
auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph);
EXPECT_EQ(ret, domi::SUCCESS);
}


} // namespace ge

+ 0
- 37
tests/ut/parser/testcase/parser_unittest.cc View File

@@ -1,37 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <iostream>

#include "parser/common/op_parser_factory.h"


namespace ge {
class UtestParser : public testing::Test {
protected:
void SetUp() {}

void TearDown() {}
};

TEST_F(UtestParser, base) {
std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW);
EXPECT_NE(factory, nullptr);
}


} // namespace ge

+ 13
- 0
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/add.pb View File

@@ -0,0 +1,13 @@

8
Placeholder Placeholder*
dtype0*
shape:
:
Placeholder_1 Placeholder*
dtype0*
shape:
6

add_test_1Add Placeholder Placeholder_1*
T0"†

+ 89
- 0
tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc View File

@@ -0,0 +1,89 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <iostream>
#include "parser/common/op_parser_factory.h"
#include "graph/operator_reg.h"
#include "external/graph/types.h"
#include "register/op_registry.h"
#include "parser/common/register_tbe.h"
#include "external/parser/tensorflow_parser.h"


namespace ge {
class UtestTensorflowParser : public testing::Test {
protected:
void SetUp() {}

void TearDown() {}

public:
void RegisterCustomOp();
};

static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) {
return SUCCESS;
}

void UtestTensorflowParser::RegisterCustomOp() {
REGISTER_CUSTOM_OP("Add")
.FrameworkType(domi::TENSORFLOW)
.OriginOpType("Add")
.ParseParamsFn(ParseParams);

std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas;
for (auto reg_data : reg_datas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegistry::Instance()->Register(reg_data);
}
domi::OpRegistry::Instance()->registrationDatas.clear();
}

namespace {
REG_OP(Data)
.INPUT(x, TensorType::ALL())
.OUTPUT(y, TensorType::ALL())
.ATTR(index, Int, 0)
.OP_END_FACTORY_REG(Data)

REG_OP(Add)
.INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16,
DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
DT_COMPLEX64, DT_STRING}))
.INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16,
DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
DT_COMPLEX64, DT_STRING}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16,
DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
DT_COMPLEX64, DT_STRING}))
.OP_END_FACTORY_REG(Add)
}

TEST_F(UtestTensorflowParser, tensorflow_parser_success) {
RegisterCustomOp();

std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/tensorflow_model/add.pb";
std::map<ge::AscendString, ge::AscendString> parser_params;
ge::Graph graph;
auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph);
EXPECT_EQ(ret, domi::SUCCESS);
}


} // namespace ge

Loading…
Cancel
Save