| @@ -40,7 +40,7 @@ if (ENABLE_OPEN_SRC) | |||||
| find_module(static_mmpa libmmpa.a ${GE_LIB_PATH}) | find_module(static_mmpa libmmpa.a ${GE_LIB_PATH}) | ||||
| elseif(ENABLE_GE_COV OR ENABLE_GE_UT) | elseif(ENABLE_GE_COV OR ENABLE_GE_UT) | ||||
| message(STATUS "Runing on llt mode, no need to depend other component") | message(STATUS "Runing on llt mode, no need to depend other component") | ||||
| elseif(ENABLE_PARSER_UT OR ENABLE_PARSER_COV) | |||||
| elseif(ENABLE_PARSER_UT OR ENABLE_PARSER_COV OR ENABLE_PARSER_ST) | |||||
| include(cmake/external_libs/gtest.cmake) | include(cmake/external_libs/gtest.cmake) | ||||
| add_subdirectory(tests) | add_subdirectory(tests) | ||||
| else() | else() | ||||
| @@ -151,6 +151,8 @@ build_parser() | |||||
| if [ "X$ENABLE_PARSER_UT" = "Xon" ]; then | if [ "X$ENABLE_PARSER_UT" = "Xon" ]; then | ||||
| make ut_parser -j8 | make ut_parser -j8 | ||||
| elif [ "X$ENABLE_PARSER_ST" = "Xon" ]; then | |||||
| make st_parser -j8 | |||||
| else | else | ||||
| make ${VERBOSE} -j${THREAD_NUM} && make install | make ${VERBOSE} -j${THREAD_NUM} && make install | ||||
| fi | fi | ||||
| @@ -194,6 +196,17 @@ if [[ "X$ENABLE_PARSER_UT" = "Xon" || "X$ENABLE_PARSER_COV" = "Xon" ]]; then | |||||
| genhtml coverage.info | genhtml coverage.info | ||||
| fi | fi | ||||
| if [[ "X$ENABLE_PARSER_ST" = "Xon" ]]; then | |||||
| cp ${BUILD_PATH}/tests/st/st_parser ${OUTPUT_PATH} | |||||
| RUN_TEST_CASE=${OUTPUT_PATH}/st_parser && ${RUN_TEST_CASE} | |||||
| if [[ "$?" -ne 0 ]]; then | |||||
| echo "!!! ST FAILED, PLEASE CHECK YOUR CHANGES !!!" | |||||
| echo -e "\033[31m${RUN_TEST_CASE}\033[0m" | |||||
| exit 1; | |||||
| fi | |||||
| fi | |||||
| # generate output package in tar form, including ut/st libraries/executables | # generate output package in tar form, including ut/st libraries/executables | ||||
| generate_package() | generate_package() | ||||
| { | { | ||||
| @@ -236,7 +249,7 @@ generate_package() | |||||
| tar -cf parser_lib.tar fwkacllib acllib atc | tar -cf parser_lib.tar fwkacllib acllib atc | ||||
| } | } | ||||
| if [[ "X$ENABLE_PARSER_UT" = "Xoff" ]]; then | |||||
| if [[ "X$ENABLE_PARSER_UT" = "Xoff" && "X$ENABLE_PARSER_ST" = "Xoff" ]]; then | |||||
| generate_package | generate_package | ||||
| fi | fi | ||||
| echo "---------------- Parser package archive generated ----------------" | echo "---------------- Parser package archive generated ----------------" | ||||
| @@ -50,7 +50,7 @@ int ErrorManager::ReportInterErrMessage(std::string error_code, const std::strin | |||||
| const std::string &ErrorManager::GetLogHeader() { | const std::string &ErrorManager::GetLogHeader() { | ||||
| static const std::string kLogHeader("GeUtStub"); | |||||
| static const std::string kLogHeader("[ParserStub]"); | |||||
| return kLogHeader; | return kLogHeader; | ||||
| } | } | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -15,20 +15,53 @@ | |||||
| */ | */ | ||||
| #include "toolchain/slog.h" | #include "toolchain/slog.h" | ||||
| #include "toolchain/plog.h" | |||||
| #include <stdarg.h> | #include <stdarg.h> | ||||
| #include <stdio.h> | #include <stdio.h> | ||||
| #include <string.h> | |||||
| void dav_log(int module_id, const char *fmt, ...) {} | void dav_log(int module_id, const char *fmt, ...) {} | ||||
| void DlogErrorInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); } | |||||
| static int log_level = DLOG_ERROR; | |||||
| void DlogWarnInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); } | |||||
| #define __DO_PRINT() \ | |||||
| do { \ | |||||
| const int FMT_BUFF_SIZE = 1024; \ | |||||
| char fmt_buff[FMT_BUFF_SIZE] = {0}; \ | |||||
| va_list valist; \ | |||||
| va_start(valist, fmt); \ | |||||
| vsnprintf(fmt_buff, FMT_BUFF_SIZE, fmt, valist); \ | |||||
| va_end(valist); \ | |||||
| printf("%s \n", fmt_buff); \ | |||||
| } while (0) | |||||
| void DlogInfoInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); } | |||||
| void DlogErrorInner(int module_id, const char *fmt, ...) { | |||||
| if (log_level > DLOG_ERROR) { | |||||
| return; | |||||
| } | |||||
| __DO_PRINT(); | |||||
| } | |||||
| void DlogWarnInner(int module_id, const char *fmt, ...) { | |||||
| if (log_level > DLOG_WARN) { | |||||
| return; | |||||
| } | |||||
| __DO_PRINT(); | |||||
| } | |||||
| void DlogDebugInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); } | |||||
| void DlogInfoInner(int module_id, const char *fmt, ...) { | |||||
| if (log_level > DLOG_INFO) { | |||||
| return; | |||||
| } | |||||
| __DO_PRINT(); | |||||
| } | |||||
| void DlogDebugInner(int module_id, const char *fmt, ...) { | |||||
| if (log_level > DLOG_DEBUG) { | |||||
| return; | |||||
| } | |||||
| __DO_PRINT(); | |||||
| } | |||||
| void DlogEventInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); } | void DlogEventInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); } | ||||
| @@ -38,11 +71,25 @@ void DlogWithKVInner(int module_id, int level, KeyValue *pst_kv_array, int kv_nu | |||||
| dav_log(module_id, fmt); | dav_log(module_id, fmt); | ||||
| } | } | ||||
| int dlog_setlevel(int module_id, int level, int enable_event) { return DLOG_DEBUG; } | |||||
| int dlog_setlevel(int module_id, int level, int enable_event) { | |||||
| log_level = level; | |||||
| return log_level; | |||||
| } | |||||
| int dlog_getlevel(int module_id, int *enable_event) { return log_level; } | |||||
| int dlog_getlevel(int module_id, int *enable_event) { return DLOG_DEBUG; } | |||||
| int CheckLogLevel(int moduleId, int log_level_check) { return log_level >= log_level_check; } | |||||
| int CheckLogLevel(int moduleId, int logLevel) | |||||
| { | |||||
| return 1; | |||||
| } | |||||
| /** | |||||
| * @ingroup plog | |||||
| * @brief DlogReportInitialize: init log in service process before all device setting. | |||||
| * @return: 0: SUCCEED, others: FAILED | |||||
| */ | |||||
| int DlogReportInitialize() { return 0; } | |||||
| /** | |||||
| * @ingroup plog | |||||
| * @brief DlogReportFinalize: release log resource in service process after all device reset. | |||||
| * @return: 0: SUCCEED, others: FAILED | |||||
| */ | |||||
| int DlogReportFinalize() { return 0; } | |||||
| @@ -0,0 +1,356 @@ | |||||
| # Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| project(st_parser) | |||||
| set(CMAKE_CXX_STANDARD 11) | |||||
| ################################################################################ | |||||
| set(PARSER_PROTO_LIST | |||||
| "${PARSER_DIR}/metadef/proto/om.proto" | |||||
| "${PARSER_DIR}/metadef/proto/ge_ir.proto" | |||||
| "${PARSER_DIR}/metadef/proto/task.proto" | |||||
| "${PARSER_DIR}/metadef/proto/tensorflow/attr_value.proto" | |||||
| "${PARSER_DIR}/metadef/proto/tensorflow/function.proto" | |||||
| "${PARSER_DIR}/metadef/proto/tensorflow/graph.proto" | |||||
| "${PARSER_DIR}/metadef/proto/tensorflow/graph_library.proto" | |||||
| "${PARSER_DIR}/metadef/proto/tensorflow/node_def.proto" | |||||
| "${PARSER_DIR}/metadef/proto/tensorflow/op_def.proto" | |||||
| "${PARSER_DIR}/metadef/proto/tensorflow/resource_handle.proto" | |||||
| "${PARSER_DIR}/metadef/proto/tensorflow/tensor.proto" | |||||
| "${PARSER_DIR}/metadef/proto/tensorflow/tensor_shape.proto" | |||||
| "${PARSER_DIR}/metadef/proto/tensorflow/types.proto" | |||||
| "${PARSER_DIR}/metadef/proto/tensorflow/versions.proto" | |||||
| "${PARSER_DIR}/metadef/proto/caffe/caffe.proto" | |||||
| "${PARSER_DIR}/metadef/proto/onnx/ge_onnx.proto" | |||||
| #"${PARSER_DIR}/metadef/proto/proto_inner/ge_onnx.proto" | |||||
| ) | |||||
| protobuf_generate(ge PARSER_PROTO_SRCS PARSER_PROTO_HDRS ${PARSER_PROTO_LIST}) | |||||
| ############ libst_parser_proto.a ############ | |||||
| add_library(st_parser_proto STATIC | |||||
| ${PARSER_PROTO_HDRS} ${PARSER_PROTO_SRCS} | |||||
| ) | |||||
| target_compile_definitions(st_parser_proto PRIVATE | |||||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||||
| google=ascend_private | |||||
| ) | |||||
| target_compile_options(st_parser_proto PRIVATE | |||||
| -O2 -g -fno-common | |||||
| ) | |||||
| target_link_libraries(st_parser_proto PRIVATE | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| ascend_protobuf | |||||
| ) | |||||
| ################################################################################ | |||||
| set(DUPLICATE_PROTO_LIST | |||||
| "${PARSER_DIR}/metadef/proto/proto_inner/ge_onnx.proto" | |||||
| ) | |||||
| protobuf_generate(ge DUP_PROTO_SRCS DUP_PROTO_HDRS ${DUPLICATE_PROTO_LIST}) | |||||
| ################################################################################ | |||||
| set(MATEDEF_SRC_FILES | |||||
| "${PARSER_DIR}/metadef/graph/aligned_ptr.cc" | |||||
| "${PARSER_DIR}/metadef/graph/anchor.cc" | |||||
| "${PARSER_DIR}/metadef/graph/ascend_string.cc" | |||||
| "${PARSER_DIR}/metadef/graph/attr_value.cc" | |||||
| "${PARSER_DIR}/metadef/graph/buffer.cc" | |||||
| "${PARSER_DIR}/metadef/graph/compute_graph.cc" | |||||
| "${PARSER_DIR}/metadef/graph/debug/graph_debug.cc" | |||||
| "${PARSER_DIR}/metadef/graph/detail/attributes_holder.cc" | |||||
| "${PARSER_DIR}/metadef/graph/format_refiner.cc" | |||||
| "${PARSER_DIR}/metadef/graph/ge_attr_define.cc" | |||||
| "${PARSER_DIR}/metadef/graph/ge_tensor.cc" | |||||
| "${PARSER_DIR}/metadef/graph/gnode.cc" | |||||
| "${PARSER_DIR}/metadef/graph/graph.cc" | |||||
| "${PARSER_DIR}/metadef/graph/inference_context.cc" | |||||
| "${PARSER_DIR}/metadef/graph/model.cc" | |||||
| "${PARSER_DIR}/metadef/graph/model_serialize.cc" | |||||
| "${PARSER_DIR}/metadef/graph/node.cc" | |||||
| "${PARSER_DIR}/metadef/graph/op_desc.cc" | |||||
| "${PARSER_DIR}/metadef/graph/operator.cc" | |||||
| "${PARSER_DIR}/metadef/graph/operator_factory.cc" | |||||
| "${PARSER_DIR}/metadef/graph/operator_factory_impl.cc" | |||||
| "${PARSER_DIR}/metadef/graph/opsproto/opsproto_manager.cc" | |||||
| "${PARSER_DIR}/metadef/graph/option/ge_context.cc" | |||||
| "${PARSER_DIR}/metadef/graph/option/ge_local_context.cc" | |||||
| "${PARSER_DIR}/metadef/graph/ref_relation.cc" | |||||
| "${PARSER_DIR}/metadef/graph/runtime_inference_context.cc" | |||||
| "${PARSER_DIR}/metadef/graph/shape_refiner.cc" | |||||
| "${PARSER_DIR}/metadef/graph/tensor.cc" | |||||
| "${PARSER_DIR}/metadef/graph/types.cc" | |||||
| "${PARSER_DIR}/metadef/graph/utils/anchor_utils.cc" | |||||
| "${PARSER_DIR}/metadef/graph/utils/ge_ir_utils.cc" | |||||
| "${PARSER_DIR}/metadef/graph/utils/graph_utils.cc" | |||||
| "${PARSER_DIR}/metadef/graph/utils/node_utils.cc" | |||||
| "${PARSER_DIR}/metadef/graph/utils/op_desc_utils.cc" | |||||
| "${PARSER_DIR}/metadef/graph/utils/tensor_utils.cc" | |||||
| "${PARSER_DIR}/metadef/graph/utils/transformer_utils.cc" | |||||
| "${PARSER_DIR}/metadef/graph/utils/tuning_utils.cc" | |||||
| "${PARSER_DIR}/metadef/graph/utils/type_utils.cc" | |||||
| "${PARSER_DIR}/metadef/ops/op_imp.cpp" | |||||
| "${PARSER_DIR}/metadef/third_party/transformer/src/axis_util.cc" | |||||
| "${PARSER_DIR}/metadef/third_party/transformer/src/expand_dimension.cc" | |||||
| "${PARSER_DIR}/metadef/third_party/transformer/src/transfer_shape_according_to_format.cc" | |||||
| ) | |||||
| # include directories | |||||
| include_directories(${CMAKE_CURRENT_LIST_DIR}) | |||||
| include_directories(${PARSER_DIR}/metadef/inc) | |||||
| include_directories(${PARSER_DIR}/metadef/inc/graph) | |||||
| include_directories(${PARSER_DIR}/metadef/inc/external) | |||||
| include_directories(${PARSER_DIR}/metadef/inc/external/graph) | |||||
| include_directories(${PARSER_DIR}/metadef/graph) | |||||
| include_directories(${PARSER_DIR}/metadef/third_party) | |||||
| include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc) | |||||
| include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/external) | |||||
| include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/external/ge) | |||||
| include_directories(${PARSER_DIR}/metadef/third_party/fwkacllib/inc) | |||||
| include_directories(${PARSER_DIR}/metadef/third_party/transformer/inc) | |||||
| include_directories(${PARSER_DIR}/metadef) | |||||
| include_directories(${CMAKE_BINARY_DIR}/proto/ge) | |||||
| include_directories(${CMAKE_BINARY_DIR}/proto/ge/proto) | |||||
| ############ libst_parser_graph.a ############ | |||||
| add_library(st_parser_graph STATIC | |||||
| ${MATEDEF_SRC_FILES} ${PARSER_PROTO_HDRS} ${DUP_PROTO_HDRS} | |||||
| ) | |||||
| target_compile_definitions(st_parser_graph PRIVATE | |||||
| google=ascend_private | |||||
| ) | |||||
| target_compile_options(st_parser_graph PRIVATE | |||||
| -O2 -g -fno-common | |||||
| ) | |||||
| target_link_libraries(st_parser_graph PRIVATE | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| c_sec ascend_protobuf | |||||
| ) | |||||
| ################################################################################ | |||||
| set(REGISTER_SRC_FILES | |||||
| "${PARSER_DIR}/metadef/register/auto_mapping_util.cpp" | |||||
| "${PARSER_DIR}/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.cc" | |||||
| "${PARSER_DIR}/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.cc" | |||||
| "${PARSER_DIR}/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.cc" | |||||
| "${PARSER_DIR}/metadef/register/graph_optimizer/fusion_statistic/fusion_statistic_recorder.cc" | |||||
| "${PARSER_DIR}/metadef/register/graph_optimizer/graph_fusion/fusion_pass_registry.cc" | |||||
| "${PARSER_DIR}/metadef/register/graph_optimizer/graph_fusion/fusion_pattern.cc" | |||||
| "${PARSER_DIR}/metadef/register/graph_optimizer/graph_fusion/graph_fusion_pass_base.cc" | |||||
| "${PARSER_DIR}/metadef/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass.cc" | |||||
| "${PARSER_DIR}/metadef/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.cc" | |||||
| "${PARSER_DIR}/metadef/register/host_cpu_context.cc" | |||||
| "${PARSER_DIR}/metadef/register/infer_data_slice_registry.cc" | |||||
| "${PARSER_DIR}/metadef/register/ops_kernel_builder_registry.cc" | |||||
| "${PARSER_DIR}/metadef/register/op_kernel_registry.cpp" | |||||
| "${PARSER_DIR}/metadef/register/op_tiling.cpp" | |||||
| "${PARSER_DIR}/metadef/register/op_tiling_registry.cpp" | |||||
| "${PARSER_DIR}/metadef/register/register.cpp" | |||||
| "${PARSER_DIR}/metadef/register/register_format_transfer.cc" | |||||
| "${PARSER_DIR}/metadef/register/register_pass.cpp" | |||||
| "${PARSER_DIR}/metadef/register/scope/scope_graph.cc" | |||||
| "${PARSER_DIR}/metadef/register/scope/scope_pass.cc" | |||||
| "${PARSER_DIR}/metadef/register/scope/scope_pass_registry.cc" | |||||
| "${PARSER_DIR}/metadef/register/scope/scope_pattern.cc" | |||||
| "${PARSER_DIR}/metadef/register/scope/scope_util.cc" | |||||
| "${PARSER_DIR}/metadef/register/tensor_assign.cpp" | |||||
| "${PARSER_DIR}/metadef/register/prototype_pass_registry.cc" | |||||
| ) | |||||
| # include directories | |||||
| include_directories(${CMAKE_CURRENT_LIST_DIR}) | |||||
| include_directories(${CMAKE_BINARY_DIR}/proto/ge) | |||||
| include_directories(${PARSER_DIR}/metadef) | |||||
| include_directories(${PARSER_DIR}/metadef/graph) | |||||
| include_directories(${PARSER_DIR}/metadef/inc) | |||||
| include_directories(${PARSER_DIR}/metadef/inc/external) | |||||
| include_directories(${PARSER_DIR}/metadef/inc/register) | |||||
| include_directories(${PARSER_DIR}/metadef/third_party/fwkacllib/inc) | |||||
| include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc) | |||||
| include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/external) | |||||
| include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework) | |||||
| ############ libst_parser_register.a ############ | |||||
| add_library(st_parser_register STATIC | |||||
| ${REGISTER_SRC_FILES} ${PARSER_PROTO_HDRS} | |||||
| ) | |||||
| target_compile_definitions(st_parser_register PRIVATE | |||||
| google=ascend_private | |||||
| ) | |||||
| target_compile_options(st_parser_register PRIVATE | |||||
| -O2 -g -fno-common | |||||
| ) | |||||
| target_link_libraries(st_parser_register PRIVATE | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| c_sec ascend_protobuf json | |||||
| ) | |||||
| ################################################################################ | |||||
| set(PARSER_SRC_FILES | |||||
| "${PARSER_DIR}/parser/caffe/caffe_custom_parser_adapter.cc" | |||||
| "${PARSER_DIR}/parser/caffe/caffe_data_parser.cc" | |||||
| "${PARSER_DIR}/parser/caffe/caffe_op_parser.cc" | |||||
| "${PARSER_DIR}/parser/caffe/caffe_parser.cc" | |||||
| "${PARSER_DIR}/parser/caffe/caffe_reshape_parser.cc" | |||||
| "${PARSER_DIR}/parser/common/acl_graph_parser_util.cc" | |||||
| "${PARSER_DIR}/parser/common/convert/pb2json.cc" | |||||
| "${PARSER_DIR}/parser/common/convert/message2operator.cc" | |||||
| "${PARSER_DIR}/parser/common/data_op_parser.cc" | |||||
| "${PARSER_DIR}/parser/common/model_saver.cc" | |||||
| "${PARSER_DIR}/parser/common/op_def/arg_op.cc" | |||||
| "${PARSER_DIR}/parser/common/op_def/constant_op.cc" | |||||
| "${PARSER_DIR}/parser/common/op_def/defs.cc" | |||||
| "${PARSER_DIR}/parser/common/op_def/fill_op.cc" | |||||
| "${PARSER_DIR}/parser/common/op_def/frameworkop_op.cc" | |||||
| "${PARSER_DIR}/parser/common/op_def/ir_pb_converter.cc" | |||||
| "${PARSER_DIR}/parser/common/op_def/no_op_op.cc" | |||||
| "${PARSER_DIR}/parser/common/op_def/operator.cc" | |||||
| "${PARSER_DIR}/parser/common/op_def/op_schema.cc" | |||||
| "${PARSER_DIR}/parser/common/op_def/ref_switch_op.cc" | |||||
| "${PARSER_DIR}/parser/common/op_def/shape_n_op.cc" | |||||
| "${PARSER_DIR}/parser/common/op_def/variable_op.cc" | |||||
| "${PARSER_DIR}/parser/common/op_def/var_is_initialized_op_op.cc" | |||||
| "${PARSER_DIR}/parser/common/op_map.cc" | |||||
| "${PARSER_DIR}/parser/common/op_parser_factory.cc" | |||||
| "${PARSER_DIR}/parser/common/parser_api.cc" | |||||
| "${PARSER_DIR}/parser/common/parser_factory.cc" | |||||
| "${PARSER_DIR}/parser/common/parser_fp16_t.cc" | |||||
| "${PARSER_DIR}/parser/common/parser_inner_ctx.cc" | |||||
| "${PARSER_DIR}/parser/common/parser_types.cc" | |||||
| "${PARSER_DIR}/parser/common/parser_utils.cc" | |||||
| "${PARSER_DIR}/parser/common/pass_manager.cc" | |||||
| "${PARSER_DIR}/parser/common/pre_checker.cc" | |||||
| "${PARSER_DIR}/parser/common/proto_file_parser.cc" | |||||
| "${PARSER_DIR}/parser/common/prototype_pass_manager.cc" | |||||
| "${PARSER_DIR}/parser/common/register_tbe.cc" | |||||
| "${PARSER_DIR}/parser/common/tbe_plugin_loader.cc" | |||||
| "${PARSER_DIR}/parser/common/thread_pool.cc" | |||||
| "${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc" | |||||
| "${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc" | |||||
| "${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc" | |||||
| "${PARSER_DIR}/parser/onnx/onnx_data_parser.cc" | |||||
| "${PARSER_DIR}/parser/onnx/onnx_parser.cc" | |||||
| "${PARSER_DIR}/parser/onnx/onnx_util.cc" | |||||
| "${PARSER_DIR}/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc" | |||||
| "${PARSER_DIR}/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/graph_functiondef.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/graph_optimizer.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/iterator_fusion_pass.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/scope/scope_pass_manager.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_arg_parser.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_constant_parser.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_custom_parser_adapter.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_data_parser.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_enter_parser.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_fill_parser.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_frameworkop_parser.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_fusionop_util.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_fusion_custom_parser_adapter.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_fusion_op_parser.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_identity_parser.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_merge_parser.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_no_op_parser.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_parser.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_ref_switch_parser.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_reshape_parser.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_shape_n_parser.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_squeeze_parser.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_util.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_variable_v2_parser.cc" | |||||
| "${PARSER_DIR}/parser/tensorflow/tensorflow_var_is_initialized_op_parser.cc" | |||||
| ) | |||||
| # include directories | |||||
| include_directories(${CMAKE_CURRENT_LIST_DIR}) | |||||
| include_directories(${CMAKE_BINARY_DIR}/proto/ge) | |||||
| include_directories(${PARSER_DIR}) | |||||
| include_directories(${PARSER_DIR}/inc) | |||||
| include_directories(${PARSER_DIR}/parser) | |||||
| include_directories(${PARSER_DIR}/parser/onnx) | |||||
| include_directories(${PARSER_DIR}/tests) | |||||
| include_directories(${PARSER_DIR}/metadef/inc) | |||||
| include_directories(${PARSER_DIR}/metadef/inc/external) | |||||
| include_directories(${PARSER_DIR}/metadef/inc/register) | |||||
| include_directories(${PARSER_DIR}/metadef/third_party/fwkacllib/inc) | |||||
| include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc) | |||||
| include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/external) | |||||
| include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework) | |||||
| set(PARSER_ST_FILES | |||||
| "parser_st_utils.cc" | |||||
| "testcase/test_main.cc" | |||||
| "testcase/test_onnx_parser.cc" | |||||
| "testcase/test_caffe_parser.cc" | |||||
| "testcase/test_tensorflow_parser.cc" | |||||
| ) | |||||
| ############ libst_parser_common.a ############ | |||||
| add_library(st_parser_common STATIC | |||||
| ${PARSER_SRC_FILES} ${PARSER_PROTO_HDRS} | |||||
| ) | |||||
| target_compile_definitions(st_parser_common PRIVATE | |||||
| google=ascend_private | |||||
| ) | |||||
| target_compile_options(st_parser_common PRIVATE | |||||
| -g --coverage -fprofile-arcs -ftest-coverage | |||||
| -Werror=format | |||||
| ) | |||||
| target_link_libraries(st_parser_common PRIVATE | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| st_parser_proto st_parser_graph c_sec | |||||
| ascend_protobuf | |||||
| json | |||||
| ) | |||||
| ################################################################################ | |||||
| add_executable(st_parser | |||||
| ${PARSER_ST_FILES} ${PARSER_PROTO_SRCS} | |||||
| ) | |||||
| target_compile_options(st_parser PRIVATE | |||||
| -g | |||||
| ) | |||||
| target_compile_definitions(st_parser PRIVATE | |||||
| google=ascend_private | |||||
| ) | |||||
| target_link_libraries(st_parser | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| st_parser_proto | |||||
| -Wl,--whole-archive st_parser_common -Wl,--no-whole-archive | |||||
| st_parser_graph st_parser_register error_manager_stub mmpa_stub attr_util_stub | |||||
| gtest gtest_main slog_stub ascend_protobuf c_sec -lrt -ldl -lgcov | |||||
| ) | |||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "st/parser_st_utils.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| namespace ge { | |||||
| void ParerSTestsUtils::ClearParserInnerCtx() { | |||||
| ge::GetParserContext().input_nodes_format_map.clear(); | |||||
| ge::GetParserContext().output_formats.clear(); | |||||
| ge::GetParserContext().user_input_dims.clear(); | |||||
| ge::GetParserContext().input_dims.clear(); | |||||
| ge::GetParserContext().op_conf_map.clear(); | |||||
| ge::GetParserContext().user_out_nodes.clear(); | |||||
| ge::GetParserContext().default_out_nodes.clear(); | |||||
| ge::GetParserContext().out_nodes_map.clear(); | |||||
| ge::GetParserContext().user_out_tensors.clear(); | |||||
| ge::GetParserContext().net_out_nodes.clear(); | |||||
| ge::GetParserContext().out_tensor_names.clear(); | |||||
| ge::GetParserContext().data_tensor_names.clear(); | |||||
| ge::GetParserContext().is_dynamic_input = false; | |||||
| ge::GetParserContext().train_flag = false; | |||||
| ge::GetParserContext().format = domi::DOMI_TENSOR_ND; | |||||
| ge::GetParserContext().type = domi::FRAMEWORK_RESERVED; | |||||
| ge::GetParserContext().run_mode = GEN_OM_MODEL; | |||||
| ge::GetParserContext().custom_proto_path = ""; | |||||
| ge::GetParserContext().caffe_proto_path = ""; | |||||
| ge::GetParserContext().enable_scope_fusion_passes = ""; | |||||
| GELOGI("Clear parser inner context successfully."); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,29 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_PARSER_TESTS_UT_PARSER_H_ | |||||
| #define GE_PARSER_TESTS_UT_PARSER_H_ | |||||
| #include "framework/omg/parser/parser_inner_ctx.h" | |||||
| namespace ge { | |||||
| class ParerSTestsUtils { | |||||
| public: | |||||
| static void ClearParserInnerCtx(); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_PARSER_TESTS_UT_PARSER_H_ | |||||
| @@ -0,0 +1,14 @@ | |||||
| name: "TestAbs" | |||||
| layer { | |||||
| name: "data" | |||||
| type: "Input" | |||||
| top: "data" | |||||
| input_param { shape: { dim: 64 dim: 1 dim: 28 dim: 28 } } | |||||
| } | |||||
| layer { | |||||
| name: "abs" | |||||
| type: "AbsVal" | |||||
| bottom: "data" | |||||
| top: "abs_out" | |||||
| } | |||||
| @@ -0,0 +1,13 @@ | |||||
| 8 | |||||
| PlaceholderPlaceholder* | |||||
| dtype0* | |||||
| shape: | |||||
| : | |||||
| Placeholder_1Placeholder* | |||||
| dtype0* | |||||
| shape: | |||||
| 6 | |||||
| add_test_1AddPlaceholder Placeholder_1* | |||||
| T0"† | |||||
| @@ -0,0 +1,102 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include "parser/common/op_parser_factory.h" | |||||
| #include "graph/operator_reg.h" | |||||
| #include "register/op_registry.h" | |||||
| #include "parser/common/register_tbe.h" | |||||
| #include "framework/omg/parser/model_parser.h" | |||||
| #include "framework/omg/parser/parser_factory.h" | |||||
| #include "external/parser/caffe_parser.h" | |||||
| #include "st/parser_st_utils.h" | |||||
| #include "external/ge/ge_api_types.h" | |||||
| namespace ge { | |||||
| class STestCaffeParser : public testing::Test { | |||||
| protected: | |||||
| void SetUp() { | |||||
| ParerSTestsUtils::ClearParserInnerCtx(); | |||||
| RegisterCustomOp(); | |||||
| } | |||||
| void TearDown() {} | |||||
| public: | |||||
| void RegisterCustomOp(); | |||||
| }; | |||||
| static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) { | |||||
| return SUCCESS; | |||||
| } | |||||
| void STestCaffeParser::RegisterCustomOp() { | |||||
| REGISTER_CUSTOM_OP("Data") | |||||
| .FrameworkType(domi::CAFFE) | |||||
| .OriginOpType("Input") | |||||
| .ParseParamsFn(ParseParams); | |||||
| REGISTER_CUSTOM_OP("Abs") | |||||
| .FrameworkType(domi::CAFFE) | |||||
| .OriginOpType("AbsVal") | |||||
| .ParseParamsFn(ParseParams); | |||||
| std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | |||||
| for (auto reg_data : reg_datas) { | |||||
| OpRegistrationTbe::Instance()->Finalize(reg_data); | |||||
| domi::OpRegistry::Instance()->Register(reg_data); | |||||
| } | |||||
| domi::OpRegistry::Instance()->registrationDatas.clear(); | |||||
| } | |||||
| namespace { | |||||
| REG_OP(Data) | |||||
| .INPUT(x, TensorType::ALL()) | |||||
| .OUTPUT(y, TensorType::ALL()) | |||||
| .ATTR(index, Int, 0) | |||||
| .OP_END_FACTORY_REG(Data) | |||||
| REG_OP(Abs) | |||||
| .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) | |||||
| .OP_END_FACTORY_REG(Abs) | |||||
| } | |||||
| TEST_F(STestCaffeParser, caffe_parser_user_output_with_default) { | |||||
| std::string case_dir = __FILE__; | |||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||||
| std::string model_file = case_dir + "/origin_models/caffe_abs.pbtxt"; | |||||
| auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::CAFFE); | |||||
| ASSERT_NE(model_parser, nullptr); | |||||
| ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmp_graph"); | |||||
| ASSERT_NE(compute_graph, nullptr); | |||||
| ge::Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | |||||
| auto ret = model_parser->Parse(model_file.c_str(), graph); | |||||
| ASSERT_EQ(ret, GRAPH_SUCCESS); | |||||
| AclGrphParseUtil acl_graph_parse_util; | |||||
| std::map<AscendString, AscendString> parser_params; | |||||
| auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params); | |||||
| ASSERT_EQ(status, SUCCESS); | |||||
| auto output_nodes_info = compute_graph->GetGraphOutNodesInfo(); | |||||
| ASSERT_EQ(output_nodes_info.size(), 1); | |||||
| EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "abs"); | |||||
| EXPECT_EQ((output_nodes_info.at(0).second), 0); | |||||
| auto &net_out_name = ge::GetParserContext().net_out_nodes; | |||||
| ASSERT_EQ(net_out_name.size(), 1); | |||||
| EXPECT_EQ(net_out_name.at(0), "abs:0:abs_out"); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,27 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <iostream> | |||||
| #include <gtest/gtest.h> | |||||
| using namespace std; | |||||
| int main(int argc, char **argv) { | |||||
| testing::InitGoogleTest(&argc, argv); | |||||
| int ret = RUN_ALL_TESTS(); | |||||
| std::cout << "Finish parser st." << std::endl; | |||||
| return ret; | |||||
| } | |||||
| @@ -0,0 +1,185 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include <iostream> | |||||
| #include "parser/common/op_parser_factory.h" | |||||
| #include "graph/operator_reg.h" | |||||
| #include "register/op_registry.h" | |||||
| #include "parser/common/register_tbe.h" | |||||
| #include "external/parser/onnx_parser.h" | |||||
| #include "st/parser_st_utils.h" | |||||
| #include "external/ge/ge_api_types.h" | |||||
| namespace ge { | |||||
| class STestOnnxParser : public testing::Test { | |||||
| protected: | |||||
| void SetUp() { | |||||
| ParerSTestsUtils::ClearParserInnerCtx(); | |||||
| RegisterCustomOp(); | |||||
| } | |||||
| void TearDown() {} | |||||
| public: | |||||
| void RegisterCustomOp(); | |||||
| }; | |||||
| static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) { | |||||
| return SUCCESS; | |||||
| } | |||||
| static Status ParseParamByOpFunc(const ge::Operator &op_src, ge::Operator& op_dest) { | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ParseSubgraphPostFnIf(const std::string& subgraph_name, const ge::Graph& graph) { | |||||
| domi::AutoMappingSubgraphIOIndexFunc auto_mapping_subgraph_index_func = | |||||
| domi::FrameworkRegistry::Instance().GetAutoMappingSubgraphIOIndexFunc(domi::ONNX); | |||||
| if (auto_mapping_subgraph_index_func == nullptr) { | |||||
| std::cout<<"auto mapping if subgraph func is nullptr!"<<std::endl; | |||||
| return FAILED; | |||||
| } | |||||
| return auto_mapping_subgraph_index_func(graph, | |||||
| [&](int data_index, int &parent_index) -> Status { | |||||
| parent_index = data_index + 1; | |||||
| return SUCCESS; | |||||
| }, | |||||
| [&](int output_index, int &parent_index) -> Status { | |||||
| parent_index = output_index; | |||||
| return SUCCESS; | |||||
| }); | |||||
| } | |||||
| void STestOnnxParser::RegisterCustomOp() { | |||||
| REGISTER_CUSTOM_OP("Conv2D") | |||||
| .FrameworkType(domi::ONNX) | |||||
| .OriginOpType("ai.onnx::11::Conv") | |||||
| .ParseParamsFn(ParseParams); | |||||
| // register if op info to GE | |||||
| REGISTER_CUSTOM_OP("If") | |||||
| .FrameworkType(domi::ONNX) | |||||
| .OriginOpType({"ai.onnx::9::If", | |||||
| "ai.onnx::10::If", | |||||
| "ai.onnx::11::If", | |||||
| "ai.onnx::12::If", | |||||
| "ai.onnx::13::If"}) | |||||
| .ParseParamsFn(ParseParams) | |||||
| .ParseParamsByOperatorFn(ParseParamByOpFunc) | |||||
| .ParseSubgraphPostFn(ParseSubgraphPostFnIf); | |||||
| REGISTER_CUSTOM_OP("Add") | |||||
| .FrameworkType(domi::ONNX) | |||||
| .OriginOpType("ai.onnx::11::Add") | |||||
| .ParseParamsFn(ParseParams); | |||||
| REGISTER_CUSTOM_OP("Identity") | |||||
| .FrameworkType(domi::ONNX) | |||||
| .OriginOpType("ai.onnx::11::Identity") | |||||
| .ParseParamsFn(ParseParams); | |||||
| std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | |||||
| for (auto reg_data : reg_datas) { | |||||
| OpRegistrationTbe::Instance()->Finalize(reg_data); | |||||
| domi::OpRegistry::Instance()->Register(reg_data); | |||||
| } | |||||
| domi::OpRegistry::Instance()->registrationDatas.clear(); | |||||
| } | |||||
| namespace { | |||||
| REG_OP(Data) | |||||
| .INPUT(x, TensorType::ALL()) | |||||
| .OUTPUT(y, TensorType::ALL()) | |||||
| .ATTR(index, Int, 0) | |||||
| .OP_END_FACTORY_REG(Data) | |||||
| REG_OP(Const) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \ | |||||
| DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
| .ATTR(value, Tensor, Tensor()) | |||||
| .OP_END_FACTORY_REG(Const) | |||||
| REG_OP(Conv2D) | |||||
| .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8})) | |||||
| .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8})) | |||||
| .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) | |||||
| .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) | |||||
| .REQUIRED_ATTR(strides, ListInt) | |||||
| .REQUIRED_ATTR(pads, ListInt) | |||||
| .ATTR(dilations, ListInt, {1, 1, 1, 1}) | |||||
| .ATTR(groups, Int, 1) | |||||
| .ATTR(data_format, String, "NHWC") | |||||
| .ATTR(offset_x, Int, 0) | |||||
| .OP_END_FACTORY_REG(Conv2D) | |||||
| REG_OP(If) | |||||
| .INPUT(cond, TensorType::ALL()) | |||||
| .DYNAMIC_INPUT(input, TensorType::ALL()) | |||||
| .DYNAMIC_OUTPUT(output, TensorType::ALL()) | |||||
| .GRAPH(then_branch) | |||||
| .GRAPH(else_branch) | |||||
| .OP_END_FACTORY_REG(If) | |||||
| REG_OP(Add) | |||||
| .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .OP_END_FACTORY_REG(Add) | |||||
| REG_OP(Identity) | |||||
| .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
| DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, | |||||
| DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) | |||||
| .OP_END_FACTORY_REG(Identity) | |||||
| } | |||||
| TEST_F(STestOnnxParser, onnx_parser_user_output_with_default) { | |||||
| std::string case_dir = __FILE__; | |||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||||
| std::string model_file = case_dir + "/origin_models/onnx_conv2d.onnx"; | |||||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||||
| ge::Graph graph; | |||||
| auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); | |||||
| ASSERT_EQ(ret, GRAPH_SUCCESS); | |||||
| ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); | |||||
| auto output_nodes_info = compute_graph->GetGraphOutNodesInfo(); | |||||
| ASSERT_EQ(output_nodes_info.size(), 1); | |||||
| EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "Conv_0"); | |||||
| EXPECT_EQ((output_nodes_info.at(0).second), 0); | |||||
| auto &net_out_name = ge::GetParserContext().net_out_nodes; | |||||
| ASSERT_EQ(net_out_name.size(), 1); | |||||
| EXPECT_EQ(net_out_name.at(0), "Conv_0:0:y"); | |||||
| } | |||||
| TEST_F(STestOnnxParser, onnx_parser_if_node) { | |||||
| std::string case_dir = __FILE__; | |||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||||
| std::string model_file = case_dir + "/origin_models/onnx_if.onnx"; | |||||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||||
| ge::Graph graph; | |||||
| auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph); | |||||
| EXPECT_EQ(ret, GRAPH_SUCCESS); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,96 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include "parser/common/op_parser_factory.h" | |||||
| #include "parser/tensorflow/tensorflow_parser.h" | |||||
| #include "graph/operator_reg.h" | |||||
| #include "register/op_registry.h" | |||||
| #include "parser/common/register_tbe.h" | |||||
| #include "external/parser/tensorflow_parser.h" | |||||
| #include "st/parser_st_utils.h" | |||||
| namespace ge { | |||||
| class STestTensorflowParser : public testing::Test { | |||||
| protected: | |||||
| void SetUp() { | |||||
| ParerSTestsUtils::ClearParserInnerCtx(); | |||||
| } | |||||
| void TearDown() {} | |||||
| public: | |||||
| void RegisterCustomOp(); | |||||
| }; | |||||
| static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) { | |||||
| return SUCCESS; | |||||
| } | |||||
| void STestTensorflowParser::RegisterCustomOp() { | |||||
| REGISTER_CUSTOM_OP("Add") | |||||
| .FrameworkType(domi::TENSORFLOW) | |||||
| .OriginOpType("Add") | |||||
| .ParseParamsFn(ParseParams); | |||||
| std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas; | |||||
| for (auto reg_data : reg_datas) { | |||||
| OpRegistrationTbe::Instance()->Finalize(reg_data); | |||||
| domi::OpRegistry::Instance()->Register(reg_data); | |||||
| } | |||||
| domi::OpRegistry::Instance()->registrationDatas.clear(); | |||||
| } | |||||
| namespace { | |||||
| REG_OP(Data) | |||||
| .INPUT(x, TensorType::ALL()) | |||||
| .OUTPUT(y, TensorType::ALL()) | |||||
| .ATTR(index, Int, 0) | |||||
| .OP_END_FACTORY_REG(Data) | |||||
| REG_OP(Add) | |||||
| .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, | |||||
| DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, | |||||
| DT_COMPLEX64, DT_STRING})) | |||||
| .OP_END_FACTORY_REG(Add) | |||||
| } | |||||
| TEST_F(STestTensorflowParser, tensorflow_parser_success) { | |||||
| RegisterCustomOp(); | |||||
| std::string case_dir = __FILE__; | |||||
| case_dir = case_dir.substr(0, case_dir.find_last_of("/")); | |||||
| std::string model_file = case_dir + "/origin_models/tf_add.pb"; | |||||
| std::map<ge::AscendString, ge::AscendString> parser_params; | |||||
| ge::Graph graph; | |||||
| auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph); | |||||
| ASSERT_EQ(ret, SUCCESS); | |||||
| ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); | |||||
| auto output_nodes_info = compute_graph->GetGraphOutNodesInfo(); | |||||
| ASSERT_EQ(output_nodes_info.size(), 1); | |||||
| EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "add_test_1"); | |||||
| EXPECT_EQ((output_nodes_info.at(0).second), 0); | |||||
| auto &net_out_name = ge::GetParserContext().net_out_nodes; | |||||
| ASSERT_EQ(net_out_name.size(), 1); | |||||
| EXPECT_EQ(net_out_name.at(0), "add_test_1:0"); | |||||
| } | |||||
| } // namespace ge | |||||