| @@ -36,7 +36,7 @@ | |||||
| #include "proto/onnx/ge_onnx.pb.h" | #include "proto/onnx/ge_onnx.pb.h" | ||||
| #include "external/register/register_error_codes.h" | #include "external/register/register_error_codes.h" | ||||
| #include "framework/omg/parser/parser_types.h" | #include "framework/omg/parser/parser_types.h" | ||||
| #include "onnx_util.h" | |||||
| #include "parser/onnx/onnx_util.h" | |||||
| using Status = domi::Status; | using Status = domi::Status; | ||||
| using namespace ge::parser; | using namespace ge::parser; | ||||
| @@ -16,6 +16,12 @@ | |||||
| #include "common/util/error_manager/error_manager.h" | #include "common/util/error_manager/error_manager.h" | ||||
| namespace ErrorMessage { | |||||
| int FormatErrorMessage(char *str_dst, size_t dst_max, const char *format, ...) { | |||||
| return 0; | |||||
| } | |||||
| } | |||||
| ErrorManager &ErrorManager::GetInstance() { | ErrorManager &ErrorManager::GetInstance() { | ||||
| static ErrorManager instance; | static ErrorManager instance; | ||||
| return instance; | return instance; | ||||
| @@ -254,6 +254,8 @@ set(PARSER_SRC_FILES | |||||
| "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" | ||||
| "${PARSER_DIR}/parser/onnx/onnx_parser.cc" | "${PARSER_DIR}/parser/onnx/onnx_parser.cc" | ||||
| "${PARSER_DIR}/parser/onnx/onnx_util.cc" | "${PARSER_DIR}/parser/onnx/onnx_util.cc" | ||||
| "${PARSER_DIR}/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc" | |||||
| "${PARSER_DIR}/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/graph_functiondef.cc" | "${PARSER_DIR}/parser/tensorflow/graph_functiondef.cc" | ||||
| "${PARSER_DIR}/parser/tensorflow/graph_optimizer.cc" | "${PARSER_DIR}/parser/tensorflow/graph_optimizer.cc" | ||||
| "${PARSER_DIR}/parser/tensorflow/iterator_fusion_pass.cc" | "${PARSER_DIR}/parser/tensorflow/iterator_fusion_pass.cc" | ||||
| @@ -298,7 +300,8 @@ include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework) | |||||
| set(PARSER_UT_FILES | set(PARSER_UT_FILES | ||||
| "testcase/parser_unittest.cc" | |||||
| "testcase/onnx_parser_testcase/onnx_parser_unittest.cc" | |||||
| "testcase/tensorflow_parser_testcase/tensorflow_parser_unittest.cc" | |||||
| ) | ) | ||||
| ############ libut_parser_common.a ############ | ############ libut_parser_common.a ############ | ||||
| @@ -0,0 +1,97 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include <iostream> | |||||
| #include "parser/common/op_parser_factory.h" | |||||
| #include "graph/operator_reg.h" | |||||
| #include "external/graph/types.h" | |||||
| #include "register/op_registry.h" | |||||
| #include "parser/common/register_tbe.h" | |||||
| #include "external/parser/onnx_parser.h" | |||||
| namespace ge { | |||||
| class UtestOnnxParser : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() {} | |||||
| public: | |||||
| void RegisterCustomOp(); | |||||
| }; | |||||
| static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) { | |||||
| return SUCCESS; | |||||
| } | |||||
| void UtestOnnxParser::RegisterCustomOp() { | |||||
| REGISTER_CUSTOM_OP("Conv2D") | |||||
| .FrameworkType(domi::ONNX) | |||||
| .OriginOpType("ai.onnx::11::Conv") | |||||
| .ParseParamsFn(ParseParams); | |||||
| std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | |||||
| for (auto reg_data : reg_datas) { | |||||
| OpRegistrationTbe::Instance()->Finalize(reg_data); | |||||
| domi::OpRegistry::Instance()->Register(reg_data); | |||||
| } | |||||
| domi::OpRegistry::Instance()->registrationDatas.clear(); | |||||
| } | |||||
| namespace { | |||||
| REG_OP(Data) | |||||
| .INPUT(x, TensorType::ALL()) | |||||
| .OUTPUT(y, TensorType::ALL()) | |||||
| .ATTR(index, Int, 0) | |||||
| .OP_END_FACTORY_REG(Data) | |||||
| REG_OP(Const) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ | |||||
| DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
| .ATTR(value, Tensor, Tensor()) | |||||
| .OP_END_FACTORY_REG(Const) | |||||
| REG_OP(Conv2D) | |||||
| .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8})) | |||||
| .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8})) | |||||
| .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) | |||||
| .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) | |||||
| .REQUIRED_ATTR(strides, ListInt) | |||||
| .REQUIRED_ATTR(pads, ListInt) | |||||
| .ATTR(dilations, ListInt, {1, 1, 1, 1}) | |||||
| .ATTR(groups, Int, 1) | |||||
| .ATTR(data_format, String, "NHWC") | |||||
| .ATTR(offset_x, Int, 0) | |||||
| .OP_END_FACTORY_REG(Conv2D) | |||||
| } | |||||
| TEST_F(UtestOnnxParser, onnx_parser_success) { | |||||
| RegisterCustomOp(); | |||||
| std::string case_dir = __FILE__; | |||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||||
| std::string model_file = case_dir + "/onnx_model/conv2d.onnx"; | |||||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||||
| ge::Graph graph; | |||||
| auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); | |||||
| EXPECT_EQ(ret, domi::SUCCESS); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,37 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include <iostream> | |||||
| #include "parser/common/op_parser_factory.h" | |||||
| namespace ge { | |||||
| class UtestParser : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() {} | |||||
| }; | |||||
| TEST_F(UtestParser, base) { | |||||
| std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(domi::TENSORFLOW); | |||||
| EXPECT_NE(factory, nullptr); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,13 @@ | |||||
| 8 | |||||
| PlaceholderPlaceholder* | |||||
| dtype0* | |||||
| shape: | |||||
| : | |||||
| Placeholder_1Placeholder* | |||||
| dtype0* | |||||
| shape: | |||||
| 6 | |||||
| add_test_1AddPlaceholder Placeholder_1* | |||||
| T0"† | |||||
| @@ -0,0 +1,89 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include <iostream> | |||||
| #include "parser/common/op_parser_factory.h" | |||||
| #include "graph/operator_reg.h" | |||||
| #include "external/graph/types.h" | |||||
| #include "register/op_registry.h" | |||||
| #include "parser/common/register_tbe.h" | |||||
| #include "external/parser/tensorflow_parser.h" | |||||
| namespace ge { | |||||
| class UtestTensorflowParser : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() {} | |||||
| public: | |||||
| void RegisterCustomOp(); | |||||
| }; | |||||
| static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) { | |||||
| return SUCCESS; | |||||
| } | |||||
| void UtestTensorflowParser::RegisterCustomOp() { | |||||
| REGISTER_CUSTOM_OP("Add") | |||||
| .FrameworkType(domi::TENSORFLOW) | |||||
| .OriginOpType("Add") | |||||
| .ParseParamsFn(ParseParams); | |||||
| std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | |||||
| for (auto reg_data : reg_datas) { | |||||
| OpRegistrationTbe::Instance()->Finalize(reg_data); | |||||
| domi::OpRegistry::Instance()->Register(reg_data); | |||||
| } | |||||
| domi::OpRegistry::Instance()->registrationDatas.clear(); | |||||
| } | |||||
| namespace { | |||||
| REG_OP(Data) | |||||
| .INPUT(x, TensorType::ALL()) | |||||
| .OUTPUT(y, TensorType::ALL()) | |||||
| .ATTR(index, Int, 0) | |||||
| .OP_END_FACTORY_REG(Data) | |||||
| REG_OP(Add) | |||||
| .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .OP_END_FACTORY_REG(Add) | |||||
| } | |||||
| TEST_F(UtestTensorflowParser, tensorflow_parser_success) { | |||||
| RegisterCustomOp(); | |||||
| std::string case_dir = __FILE__; | |||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||||
| std::string model_file = case_dir + "/tensorflow_model/add.pb"; | |||||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||||
| ge::Graph graph; | |||||
| auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); | |||||
| EXPECT_EQ(ret, domi::SUCCESS); | |||||
| } | |||||
| } // namespace ge | |||||