diff --git a/parser/onnx/subgraph_adapter/subgraph_adapter.h b/parser/onnx/subgraph_adapter/subgraph_adapter.h index afb2f96..b4cc051 100644 --- a/parser/onnx/subgraph_adapter/subgraph_adapter.h +++ b/parser/onnx/subgraph_adapter/subgraph_adapter.h @@ -36,7 +36,7 @@ #include "proto/onnx/ge_onnx.pb.h" #include "external/register/register_error_codes.h" #include "framework/omg/parser/parser_types.h" -#include "onnx_util.h" +#include "parser/onnx/onnx_util.h" using Status = domi::Status; using namespace ge::parser; diff --git a/tests/depends/error_manager/src/error_manager_stub.cc b/tests/depends/error_manager/src/error_manager_stub.cc index 485fdfb..0e8cb2a 100644 --- a/tests/depends/error_manager/src/error_manager_stub.cc +++ b/tests/depends/error_manager/src/error_manager_stub.cc @@ -16,6 +16,12 @@ #include "common/util/error_manager/error_manager.h" +namespace ErrorMessage { +int FormatErrorMessage(char *str_dst, size_t dst_max, const char *format, ...) { + return 0; +} +} + ErrorManager &ErrorManager::GetInstance() { static ErrorManager instance; return instance; diff --git a/tests/ut/parser/CMakeLists.txt b/tests/ut/parser/CMakeLists.txt index 15bc5ac..ddfc6f1 100644 --- a/tests/ut/parser/CMakeLists.txt +++ b/tests/ut/parser/CMakeLists.txt @@ -254,6 +254,8 @@ set(PARSER_SRC_FILES "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" "${PARSER_DIR}/parser/onnx/onnx_parser.cc" "${PARSER_DIR}/parser/onnx/onnx_util.cc" + "${PARSER_DIR}/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc" + "${PARSER_DIR}/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc" "${PARSER_DIR}/parser/tensorflow/graph_functiondef.cc" "${PARSER_DIR}/parser/tensorflow/graph_optimizer.cc" "${PARSER_DIR}/parser/tensorflow/iterator_fusion_pass.cc" @@ -298,7 +300,8 @@ include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework) set(PARSER_UT_FILES - "testcase/parser_unittest.cc" + "testcase/onnx_parser_testcase/onnx_parser_unittest.cc" + "testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc" ) ############ libut_parser_common.a ############ diff --git a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_model/conv2d.onnx b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_model/conv2d.onnx new file mode 100644 index 0000000..aa823ed Binary files /dev/null and b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_model/conv2d.onnx differ diff --git a/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc new file mode 100644 index 0000000..532c42e --- /dev/null +++ b/tests/ut/parser/testcase/onnx_parser_testcase/onnx_parser_unittest.cc @@ -0,0 +1,97 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "parser/common/op_parser_factory.h" +#include "graph/operator_reg.h" +#include "external/graph/types.h" +#include "register/op_registry.h" +#include "parser/common/register_tbe.h" +#include "external/parser/onnx_parser.h" + + +namespace ge { +class UtestOnnxParser : public testing::Test { + protected: + void SetUp() {} + + void TearDown() {} + + public: + void RegisterCustomOp(); +}; + +static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) { + return SUCCESS; +} + +void UtestOnnxParser::RegisterCustomOp() { + REGISTER_CUSTOM_OP("Conv2D") + .FrameworkType(domi::ONNX) + .OriginOpType("ai.onnx::11::Conv") + .ParseParamsFn(ParseParams); + + std::vector reg_datas = domi::OpRegistry::Instance()->registrationDatas; + for (auto reg_data : reg_datas) { + OpRegistrationTbe::Instance()->Finalize(reg_data); + domi::OpRegistry::Instance()->Register(reg_data); + } + domi::OpRegistry::Instance()->registrationDatas.clear(); +} + +namespace { +REG_OP(Data) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .ATTR(index, Int, 0) + .OP_END_FACTORY_REG(Data) + +REG_OP(Const) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ + DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .ATTR(value, Tensor, Tensor()) + .OP_END_FACTORY_REG(Const) + +REG_OP(Conv2D) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8})) + .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8})) + .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) + .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) + .REQUIRED_ATTR(strides, ListInt) + .REQUIRED_ATTR(pads, ListInt) + .ATTR(dilations, ListInt, {1, 1, 1, 1}) + .ATTR(groups, Int, 1) + .ATTR(data_format, String, "NHWC") + .ATTR(offset_x, Int, 0) + .OP_END_FACTORY_REG(Conv2D) +} + +TEST_F(UtestOnnxParser, onnx_parser_success) { + RegisterCustomOp(); + + std::string case_dir = __FILE__; + case_dir = case_dir.substr(0, case_dir.find_last_of("/")); + std::string model_file = case_dir + "/onnx_model/conv2d.onnx"; + std::map parser_params; + ge::Graph graph; + auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); + EXPECT_EQ(ret, domi::SUCCESS); +} + + +} // namespace ge \ No newline at end of file diff --git a/tests/ut/parser/testcase/parser_unittest.cc b/tests/ut/parser/testcase/parser_unittest.cc deleted file mode 100644 index 4885256..0000000 --- a/tests/ut/parser/testcase/parser_unittest.cc +++ /dev/null @@ -1,37 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include "parser/common/op_parser_factory.h" - - -namespace ge { -class UtestParser : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(UtestParser, base) { - std::shared_ptr factory = OpParserFactory::Instance(domi::TENSORFLOW); - EXPECT_NE(factory, nullptr); -} - - -} // namespace ge \ No newline at end of file diff --git a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/add.pb b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/add.pb new file mode 100644 index 0000000..b9a42a3 --- /dev/null +++ b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_model/add.pb @@ -0,0 +1,13 @@ + +8 + Placeholder Placeholder* +dtype0* +shape: +: + Placeholder_1 Placeholder* +dtype0* +shape: +6 + +add_test_1Add Placeholder Placeholder_1* +T0"† \ No newline at end of file diff --git a/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc new file mode 100644 index 0000000..9a7d910 --- /dev/null +++ b/tests/ut/parser/testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc @@ -0,0 +1,89 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "parser/common/op_parser_factory.h" +#include "graph/operator_reg.h" +#include "external/graph/types.h" +#include "register/op_registry.h" +#include "parser/common/register_tbe.h" +#include "external/parser/tensorflow_parser.h" + + +namespace ge { +class UtestTensorflowParser : public testing::Test { + protected: + void SetUp() {} + + void TearDown() {} + + public: + void RegisterCustomOp(); +}; + +static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) { + return SUCCESS; +} + +void UtestTensorflowParser::RegisterCustomOp() { + REGISTER_CUSTOM_OP("Add") + .FrameworkType(domi::TENSORFLOW) + .OriginOpType("Add") + .ParseParamsFn(ParseParams); + + std::vector reg_datas = domi::OpRegistry::Instance()->registrationDatas; + for (auto reg_data : reg_datas) { + OpRegistrationTbe::Instance()->Finalize(reg_data); + domi::OpRegistry::Instance()->Register(reg_data); + } + domi::OpRegistry::Instance()->registrationDatas.clear(); +} + +namespace { +REG_OP(Data) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .ATTR(index, Int, 0) + .OP_END_FACTORY_REG(Data) + +REG_OP(Add) + .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, + DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, + DT_COMPLEX64, DT_STRING})) + .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, + DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, + DT_COMPLEX64, DT_STRING})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, + DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, + DT_COMPLEX64, DT_STRING})) + .OP_END_FACTORY_REG(Add) +} + +TEST_F(UtestTensorflowParser, tensorflow_parser_success) { + RegisterCustomOp(); + + std::string case_dir = __FILE__; + case_dir = case_dir.substr(0, case_dir.find_last_of("/")); + std::string model_file = case_dir + "/tensorflow_model/add.pb"; + std::map parser_params; + ge::Graph graph; + auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); + EXPECT_EQ(ret, domi::SUCCESS); +} + + +} // namespace ge \ No newline at end of file