| @@ -36,7 +36,7 @@ | |||
| #include "proto/onnx/ge_onnx.pb.h" | |||
| #include "external/register/register_error_codes.h" | |||
| #include "framework/omg/parser/parser_types.h" | |||
| #include "onnx_util.h" | |||
| #include "parser/onnx/onnx_util.h" | |||
| using Status = domi::Status; | |||
| using namespace ge::parser; | |||
| @@ -16,6 +16,12 @@ | |||
| #include "common/util/error_manager/error_manager.h" | |||
| namespace ErrorMessage { | |||
| int FormatErrorMessage(char *str_dst, size_t dst_max, const char *format, ...) { | |||
| return 0; | |||
| } | |||
| } | |||
| ErrorManager &ErrorManager::GetInstance() { | |||
| static ErrorManager instance; | |||
| return instance; | |||
| @@ -254,6 +254,8 @@ set(PARSER_SRC_FILES | |||
| "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" | |||
| "${PARSER_DIR}/parser/onnx/onnx_parser.cc" | |||
| "${PARSER_DIR}/parser/onnx/onnx_util.cc" | |||
| "${PARSER_DIR}/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc" | |||
| "${PARSER_DIR}/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc" | |||
| "${PARSER_DIR}/parser/tensorflow/graph_functiondef.cc" | |||
| "${PARSER_DIR}/parser/tensorflow/graph_optimizer.cc" | |||
| "${PARSER_DIR}/parser/tensorflow/iterator_fusion_pass.cc" | |||
| @@ -298,7 +300,8 @@ include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework) | |||
| set(PARSER_UT_FILES | |||
| "testcase/parser_unittest.cc" | |||
| "testcase/onnx_parser_testcase/onnx_parser_unittest.cc" | |||
| "testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc" | |||
| ) | |||
| ############ libut_parser_common.a ############ | |||
| @@ -0,0 +1,97 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <gtest/gtest.h> | |||
| #include <iostream> | |||
| #include "parser/common/op_parser_factory.h" | |||
| #include "graph/operator_reg.h" | |||
| #include "external/graph/types.h" | |||
| #include "register/op_registry.h" | |||
| #include "parser/common/register_tbe.h" | |||
| #include "external/parser/onnx_parser.h" | |||
| namespace ge { | |||
| class UtestOnnxParser : public testing::Test { | |||
| protected: | |||
| void SetUp() {} | |||
| void TearDown() {} | |||
| public: | |||
| void RegisterCustomOp(); | |||
| }; | |||
| static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) { | |||
| return SUCCESS; | |||
| } | |||
| void UtestOnnxParser::RegisterCustomOp() { | |||
| REGISTER_CUSTOM_OP("Conv2D") | |||
| .FrameworkType(domi::ONNX) | |||
| .OriginOpType("ai.onnx::11::Conv") | |||
| .ParseParamsFn(ParseParams); | |||
| std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | |||
| for (auto reg_data : reg_datas) { | |||
| OpRegistrationTbe::Instance()->Finalize(reg_data); | |||
| domi::OpRegistry::Instance()->Register(reg_data); | |||
| } | |||
| domi::OpRegistry::Instance()->registrationDatas.clear(); | |||
| } | |||
| namespace { | |||
| REG_OP(Data) | |||
| .INPUT(x, TensorType::ALL()) | |||
| .OUTPUT(y, TensorType::ALL()) | |||
| .ATTR(index, Int, 0) | |||
| .OP_END_FACTORY_REG(Data) | |||
| REG_OP(Const) | |||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ | |||
| DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||
| .ATTR(value, Tensor, Tensor()) | |||
| .OP_END_FACTORY_REG(Const) | |||
| REG_OP(Conv2D) | |||
| .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8})) | |||
| .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8})) | |||
| .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) | |||
| .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) | |||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) | |||
| .REQUIRED_ATTR(strides, ListInt) | |||
| .REQUIRED_ATTR(pads, ListInt) | |||
| .ATTR(dilations, ListInt, {1, 1, 1, 1}) | |||
| .ATTR(groups, Int, 1) | |||
| .ATTR(data_format, String, "NHWC") | |||
| .ATTR(offset_x, Int, 0) | |||
| .OP_END_FACTORY_REG(Conv2D) | |||
| } | |||
| TEST_F(UtestOnnxParser, onnx_parser_success) { | |||
| RegisterCustomOp(); | |||
| std::string case_dir = __FILE__; | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/onnx_model/conv2d.onnx"; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||
| ge::Graph graph; | |||
| auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); | |||
| EXPECT_EQ(ret, domi::SUCCESS); | |||
| } | |||
| } // namespace ge | |||
| @@ -1,37 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <gtest/gtest.h> | |||
| #include <iostream> | |||
| #include "parser/common/op_parser_factory.h" | |||
| namespace ge { | |||
| class UtestParser : public testing::Test { | |||
| protected: | |||
| void SetUp() {} | |||
| void TearDown() {} | |||
| }; | |||
| TEST_F(UtestParser, base) { | |||
| std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW); | |||
| EXPECT_NE(factory, nullptr); | |||
| } | |||
| } // namespace ge | |||
| @@ -0,0 +1,13 @@ | |||
| 8 | |||
| PlaceholderPlaceholder* | |||
| dtype0* | |||
| shape: | |||
| : | |||
| Placeholder_1Placeholder* | |||
| dtype0* | |||
| shape: | |||
| 6 | |||
| add_test_1AddPlaceholder Placeholder_1* | |||
| T0"† | |||
| @@ -0,0 +1,89 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <gtest/gtest.h> | |||
| #include <iostream> | |||
| #include "parser/common/op_parser_factory.h" | |||
| #include "graph/operator_reg.h" | |||
| #include "external/graph/types.h" | |||
| #include "register/op_registry.h" | |||
| #include "parser/common/register_tbe.h" | |||
| #include "external/parser/tensorflow_parser.h" | |||
| namespace ge { | |||
| class UtestTensorflowParser : public testing::Test { | |||
| protected: | |||
| void SetUp() {} | |||
| void TearDown() {} | |||
| public: | |||
| void RegisterCustomOp(); | |||
| }; | |||
| static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) { | |||
| return SUCCESS; | |||
| } | |||
| void UtestTensorflowParser::RegisterCustomOp() { | |||
| REGISTER_CUSTOM_OP("Add") | |||
| .FrameworkType(domi::TENSORFLOW) | |||
| .OriginOpType("Add") | |||
| .ParseParamsFn(ParseParams); | |||
| std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | |||
| for (auto reg_data : reg_datas) { | |||
| OpRegistrationTbe::Instance()->Finalize(reg_data); | |||
| domi::OpRegistry::Instance()->Register(reg_data); | |||
| } | |||
| domi::OpRegistry::Instance()->registrationDatas.clear(); | |||
| } | |||
| namespace { | |||
| REG_OP(Data) | |||
| .INPUT(x, TensorType::ALL()) | |||
| .OUTPUT(y, TensorType::ALL()) | |||
| .ATTR(index, Int, 0) | |||
| .OP_END_FACTORY_REG(Data) | |||
| REG_OP(Add) | |||
| .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||
| DT_COMPLEX64, DT_STRING})) | |||
| .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||
| DT_COMPLEX64, DT_STRING})) | |||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||
| DT_COMPLEX64, DT_STRING})) | |||
| .OP_END_FACTORY_REG(Add) | |||
| } | |||
| TEST_F(UtestTensorflowParser, tensorflow_parser_success) { | |||
| RegisterCustomOp(); | |||
| std::string case_dir = __FILE__; | |||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||
| std::string model_file = case_dir + "/tensorflow_model/add.pb"; | |||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||
| ge::Graph graph; | |||
| auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); | |||
| EXPECT_EQ(ret, domi::SUCCESS); | |||
| } | |||
| } // namespace ge | |||