Browse Source

add parser st

pull/351/head
wangzhengjun 4 years ago
parent
commit
5ae547df25
15 changed files with 941 additions and 15 deletions
  1. +1
    -1
      CMakeLists.txt
  2. +14
    -1
      build.sh
  3. +1
    -1
      tests/depends/error_manager/src/error_manager_stub.cc
  4. +59
    -12
      tests/depends/slog/src/slog_stub.cc
  5. +356
    -0
      tests/st/CMakeLists.txt
  6. +44
    -0
      tests/st/parser_st_utils.cc
  7. +29
    -0
      tests/st/parser_st_utils.h
  8. +14
    -0
      tests/st/testcase/origin_models/caffe_abs.pbtxt
  9. BIN
      tests/st/testcase/origin_models/onnx_conv2d.onnx
  10. BIN
      tests/st/testcase/origin_models/onnx_if.onnx
  11. +13
    -0
      tests/st/testcase/origin_models/tf_add.pb
  12. +102
    -0
      tests/st/testcase/test_caffe_parser.cc
  13. +27
    -0
      tests/st/testcase/test_main.cc
  14. +185
    -0
      tests/st/testcase/test_onnx_parser.cc
  15. +96
    -0
      tests/st/testcase/test_tensorflow_parser.cc

+ 1
- 1
CMakeLists.txt View File

@@ -40,7 +40,7 @@ if (ENABLE_OPEN_SRC)
find_module(static_mmpa libmmpa.a ${GE_LIB_PATH})
elseif(ENABLE_GE_COV OR ENABLE_GE_UT)
message(STATUS "Runing on llt mode, no need to depend other component")
elseif(ENABLE_PARSER_UT OR ENABLE_PARSER_COV)
elseif(ENABLE_PARSER_UT OR ENABLE_PARSER_COV OR ENABLE_PARSER_ST)
include(cmake/external_libs/gtest.cmake)
add_subdirectory(tests)
else()


+ 14
- 1
build.sh View File

@@ -151,6 +151,8 @@ build_parser()

if [ "X$ENABLE_PARSER_UT" = "Xon" ]; then
make ut_parser -j8
elif [ "X$ENABLE_PARSER_ST" = "Xon" ]; then
make st_parser -j8
else
make ${VERBOSE} -j${THREAD_NUM} && make install
fi
@@ -194,6 +196,17 @@ if [[ "X$ENABLE_PARSER_UT" = "Xon" || "X$ENABLE_PARSER_COV" = "Xon" ]]; then
genhtml coverage.info
fi

if [[ "X$ENABLE_PARSER_ST" = "Xon" ]]; then
cp ${BUILD_PATH}/tests/st/st_parser ${OUTPUT_PATH}

RUN_TEST_CASE=${OUTPUT_PATH}/st_parser && ${RUN_TEST_CASE}
if [[ "$?" -ne 0 ]]; then
echo "!!! ST FAILED, PLEASE CHECK YOUR CHANGES !!!"
echo -e "\033[31m${RUN_TEST_CASE}\033[0m"
exit 1;
fi
fi

# generate output package in tar form, including ut/st libraries/executables
generate_package()
{
@@ -236,7 +249,7 @@ generate_package()
tar -cf parser_lib.tar fwkacllib acllib atc
}

if [[ "X$ENABLE_PARSER_UT" = "Xoff" ]]; then
if [[ "X$ENABLE_PARSER_UT" = "Xoff" && "X$ENABLE_PARSER_ST" = "Xoff" ]]; then
generate_package
fi
echo "---------------- Parser package archive generated ----------------"

+ 1
- 1
tests/depends/error_manager/src/error_manager_stub.cc View File

@@ -50,7 +50,7 @@ int ErrorManager::ReportInterErrMessage(std::string error_code, const std::strin


const std::string &ErrorManager::GetLogHeader() {
static const std::string kLogHeader("GeUtStub");
static const std::string kLogHeader("[ParserStub]");
return kLogHeader;
}



+ 59
- 12
tests/depends/slog/src/slog_stub.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -15,20 +15,53 @@
*/

#include "toolchain/slog.h"
#include "toolchain/plog.h"

#include <stdarg.h>
#include <stdio.h>
#include <string.h>

void dav_log(int module_id, const char *fmt, ...) {}

void DlogErrorInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); }
static int log_level = DLOG_ERROR;

void DlogWarnInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); }
#define __DO_PRINT() \
do { \
const int FMT_BUFF_SIZE = 1024; \
char fmt_buff[FMT_BUFF_SIZE] = {0}; \
va_list valist; \
va_start(valist, fmt); \
vsnprintf(fmt_buff, FMT_BUFF_SIZE, fmt, valist); \
va_end(valist); \
printf("%s \n", fmt_buff); \
} while (0)

void DlogInfoInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); }
void DlogErrorInner(int module_id, const char *fmt, ...) {
if (log_level > DLOG_ERROR) {
return;
}
__DO_PRINT();
}

void DlogWarnInner(int module_id, const char *fmt, ...) {
if (log_level > DLOG_WARN) {
return;
}
__DO_PRINT();
}

void DlogDebugInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); }
void DlogInfoInner(int module_id, const char *fmt, ...) {
if (log_level > DLOG_INFO) {
return;
}
__DO_PRINT();
}

void DlogDebugInner(int module_id, const char *fmt, ...) {
if (log_level > DLOG_DEBUG) {
return;
}
__DO_PRINT();
}

void DlogEventInner(int module_id, const char *fmt, ...) { dav_log(module_id, fmt); }

@@ -38,11 +71,25 @@ void DlogWithKVInner(int module_id, int level, KeyValue *pst_kv_array, int kv_nu
dav_log(module_id, fmt);
}

int dlog_setlevel(int module_id, int level, int enable_event) { return DLOG_DEBUG; }
int dlog_setlevel(int module_id, int level, int enable_event) {
log_level = level;
return log_level;
}

int dlog_getlevel(int module_id, int *enable_event) { return log_level; }

int dlog_getlevel(int module_id, int *enable_event) { return DLOG_DEBUG; }
int CheckLogLevel(int moduleId, int log_level_check) { return log_level >= log_level_check; }

int CheckLogLevel(int moduleId, int logLevel)
{
return 1;
}
/**
* @ingroup plog
* @brief DlogReportInitialize: init log in service process before all device setting.
* @return: 0: SUCCEED, others: FAILED
*/
int DlogReportInitialize() { return 0; }

/**
* @ingroup plog
* @brief DlogReportFinalize: release log resource in service process after all device reset.
* @return: 0: SUCCEED, others: FAILED
*/
int DlogReportFinalize() { return 0; }

+ 356
- 0
tests/st/CMakeLists.txt View File

@@ -0,0 +1,356 @@
# Copyright 2019-2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

project(st_parser)

set(CMAKE_CXX_STANDARD 11)

################################################################################
set(PARSER_PROTO_LIST
"${PARSER_DIR}/metadef/proto/om.proto"
"${PARSER_DIR}/metadef/proto/ge_ir.proto"
"${PARSER_DIR}/metadef/proto/task.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/attr_value.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/function.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/graph.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/graph_library.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/node_def.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/op_def.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/resource_handle.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/tensor.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/tensor_shape.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/types.proto"
"${PARSER_DIR}/metadef/proto/tensorflow/versions.proto"
"${PARSER_DIR}/metadef/proto/caffe/caffe.proto"
"${PARSER_DIR}/metadef/proto/onnx/ge_onnx.proto"
#"${PARSER_DIR}/metadef/proto/proto_inner/ge_onnx.proto"
)

protobuf_generate(ge PARSER_PROTO_SRCS PARSER_PROTO_HDRS ${PARSER_PROTO_LIST})

############ libst_parser_proto.a ############
add_library(st_parser_proto STATIC
${PARSER_PROTO_HDRS} ${PARSER_PROTO_SRCS}
)

target_compile_definitions(st_parser_proto PRIVATE
PROTOBUF_INLINE_NOT_IN_HEADERS=0
google=ascend_private
)

target_compile_options(st_parser_proto PRIVATE
-O2 -g -fno-common
)

target_link_libraries(st_parser_proto PRIVATE
$<BUILD_INTERFACE:intf_pub>
ascend_protobuf
)


################################################################################
set(DUPLICATE_PROTO_LIST
"${PARSER_DIR}/metadef/proto/proto_inner/ge_onnx.proto"
)

protobuf_generate(ge DUP_PROTO_SRCS DUP_PROTO_HDRS ${DUPLICATE_PROTO_LIST})

################################################################################
set(MATEDEF_SRC_FILES
"${PARSER_DIR}/metadef/graph/aligned_ptr.cc"
"${PARSER_DIR}/metadef/graph/anchor.cc"
"${PARSER_DIR}/metadef/graph/ascend_string.cc"
"${PARSER_DIR}/metadef/graph/attr_value.cc"
"${PARSER_DIR}/metadef/graph/buffer.cc"
"${PARSER_DIR}/metadef/graph/compute_graph.cc"
"${PARSER_DIR}/metadef/graph/debug/graph_debug.cc"
"${PARSER_DIR}/metadef/graph/detail/attributes_holder.cc"
"${PARSER_DIR}/metadef/graph/format_refiner.cc"
"${PARSER_DIR}/metadef/graph/ge_attr_define.cc"
"${PARSER_DIR}/metadef/graph/ge_tensor.cc"
"${PARSER_DIR}/metadef/graph/gnode.cc"
"${PARSER_DIR}/metadef/graph/graph.cc"
"${PARSER_DIR}/metadef/graph/inference_context.cc"
"${PARSER_DIR}/metadef/graph/model.cc"
"${PARSER_DIR}/metadef/graph/model_serialize.cc"
"${PARSER_DIR}/metadef/graph/node.cc"
"${PARSER_DIR}/metadef/graph/op_desc.cc"
"${PARSER_DIR}/metadef/graph/operator.cc"
"${PARSER_DIR}/metadef/graph/operator_factory.cc"
"${PARSER_DIR}/metadef/graph/operator_factory_impl.cc"
"${PARSER_DIR}/metadef/graph/opsproto/opsproto_manager.cc"
"${PARSER_DIR}/metadef/graph/option/ge_context.cc"
"${PARSER_DIR}/metadef/graph/option/ge_local_context.cc"
"${PARSER_DIR}/metadef/graph/ref_relation.cc"
"${PARSER_DIR}/metadef/graph/runtime_inference_context.cc"
"${PARSER_DIR}/metadef/graph/shape_refiner.cc"
"${PARSER_DIR}/metadef/graph/tensor.cc"
"${PARSER_DIR}/metadef/graph/types.cc"
"${PARSER_DIR}/metadef/graph/utils/anchor_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/ge_ir_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/graph_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/node_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/op_desc_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/tensor_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/transformer_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/tuning_utils.cc"
"${PARSER_DIR}/metadef/graph/utils/type_utils.cc"
"${PARSER_DIR}/metadef/ops/op_imp.cpp"
"${PARSER_DIR}/metadef/third_party/transformer/src/axis_util.cc"
"${PARSER_DIR}/metadef/third_party/transformer/src/expand_dimension.cc"
"${PARSER_DIR}/metadef/third_party/transformer/src/transfer_shape_according_to_format.cc"
)

# include directories
include_directories(${CMAKE_CURRENT_LIST_DIR})
include_directories(${PARSER_DIR}/metadef/inc)
include_directories(${PARSER_DIR}/metadef/inc/graph)
include_directories(${PARSER_DIR}/metadef/inc/external)
include_directories(${PARSER_DIR}/metadef/inc/external/graph)
include_directories(${PARSER_DIR}/metadef/graph)
include_directories(${PARSER_DIR}/metadef/third_party)
include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc)
include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/external)
include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/external/ge)
include_directories(${PARSER_DIR}/metadef/third_party/fwkacllib/inc)
include_directories(${PARSER_DIR}/metadef/third_party/transformer/inc)
include_directories(${PARSER_DIR}/metadef)
include_directories(${CMAKE_BINARY_DIR}/proto/ge)
include_directories(${CMAKE_BINARY_DIR}/proto/ge/proto)

############ libst_parser_graph.a ############
add_library(st_parser_graph STATIC
${MATEDEF_SRC_FILES} ${PARSER_PROTO_HDRS} ${DUP_PROTO_HDRS}
)

target_compile_definitions(st_parser_graph PRIVATE
google=ascend_private
)

target_compile_options(st_parser_graph PRIVATE
-O2 -g -fno-common
)

target_link_libraries(st_parser_graph PRIVATE
$<BUILD_INTERFACE:intf_pub>
c_sec ascend_protobuf
)


################################################################################
set(REGISTER_SRC_FILES
"${PARSER_DIR}/metadef/register/auto_mapping_util.cpp"
"${PARSER_DIR}/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.cc"
"${PARSER_DIR}/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.cc"
"${PARSER_DIR}/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.cc"
"${PARSER_DIR}/metadef/register/graph_optimizer/fusion_statistic/fusion_statistic_recorder.cc"
"${PARSER_DIR}/metadef/register/graph_optimizer/graph_fusion/fusion_pass_registry.cc"
"${PARSER_DIR}/metadef/register/graph_optimizer/graph_fusion/fusion_pattern.cc"
"${PARSER_DIR}/metadef/register/graph_optimizer/graph_fusion/graph_fusion_pass_base.cc"
"${PARSER_DIR}/metadef/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass.cc"
"${PARSER_DIR}/metadef/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.cc"
"${PARSER_DIR}/metadef/register/host_cpu_context.cc"
"${PARSER_DIR}/metadef/register/infer_data_slice_registry.cc"
"${PARSER_DIR}/metadef/register/ops_kernel_builder_registry.cc"
"${PARSER_DIR}/metadef/register/op_kernel_registry.cpp"
"${PARSER_DIR}/metadef/register/op_tiling.cpp"
"${PARSER_DIR}/metadef/register/op_tiling_registry.cpp"
"${PARSER_DIR}/metadef/register/register.cpp"
"${PARSER_DIR}/metadef/register/register_format_transfer.cc"
"${PARSER_DIR}/metadef/register/register_pass.cpp"
"${PARSER_DIR}/metadef/register/scope/scope_graph.cc"
"${PARSER_DIR}/metadef/register/scope/scope_pass.cc"
"${PARSER_DIR}/metadef/register/scope/scope_pass_registry.cc"
"${PARSER_DIR}/metadef/register/scope/scope_pattern.cc"
"${PARSER_DIR}/metadef/register/scope/scope_util.cc"
"${PARSER_DIR}/metadef/register/tensor_assign.cpp"
"${PARSER_DIR}/metadef/register/prototype_pass_registry.cc"
)

# include directories
include_directories(${CMAKE_CURRENT_LIST_DIR})
include_directories(${CMAKE_BINARY_DIR}/proto/ge)
include_directories(${PARSER_DIR}/metadef)
include_directories(${PARSER_DIR}/metadef/graph)
include_directories(${PARSER_DIR}/metadef/inc)
include_directories(${PARSER_DIR}/metadef/inc/external)
include_directories(${PARSER_DIR}/metadef/inc/register)
include_directories(${PARSER_DIR}/metadef/third_party/fwkacllib/inc)
include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc)
include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/external)
include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework)

############ libst_parser_register.a ############
add_library(st_parser_register STATIC
${REGISTER_SRC_FILES} ${PARSER_PROTO_HDRS}
)

target_compile_definitions(st_parser_register PRIVATE
google=ascend_private
)

target_compile_options(st_parser_register PRIVATE
-O2 -g -fno-common
)

target_link_libraries(st_parser_register PRIVATE
$<BUILD_INTERFACE:intf_pub>
c_sec ascend_protobuf json
)


################################################################################
set(PARSER_SRC_FILES
"${PARSER_DIR}/parser/caffe/caffe_custom_parser_adapter.cc"
"${PARSER_DIR}/parser/caffe/caffe_data_parser.cc"
"${PARSER_DIR}/parser/caffe/caffe_op_parser.cc"
"${PARSER_DIR}/parser/caffe/caffe_parser.cc"
"${PARSER_DIR}/parser/caffe/caffe_reshape_parser.cc"
"${PARSER_DIR}/parser/common/acl_graph_parser_util.cc"
"${PARSER_DIR}/parser/common/convert/pb2json.cc"
"${PARSER_DIR}/parser/common/convert/message2operator.cc"
"${PARSER_DIR}/parser/common/data_op_parser.cc"
"${PARSER_DIR}/parser/common/model_saver.cc"
"${PARSER_DIR}/parser/common/op_def/arg_op.cc"
"${PARSER_DIR}/parser/common/op_def/constant_op.cc"
"${PARSER_DIR}/parser/common/op_def/defs.cc"
"${PARSER_DIR}/parser/common/op_def/fill_op.cc"
"${PARSER_DIR}/parser/common/op_def/frameworkop_op.cc"
"${PARSER_DIR}/parser/common/op_def/ir_pb_converter.cc"
"${PARSER_DIR}/parser/common/op_def/no_op_op.cc"
"${PARSER_DIR}/parser/common/op_def/operator.cc"
"${PARSER_DIR}/parser/common/op_def/op_schema.cc"
"${PARSER_DIR}/parser/common/op_def/ref_switch_op.cc"
"${PARSER_DIR}/parser/common/op_def/shape_n_op.cc"
"${PARSER_DIR}/parser/common/op_def/variable_op.cc"
"${PARSER_DIR}/parser/common/op_def/var_is_initialized_op_op.cc"
"${PARSER_DIR}/parser/common/op_map.cc"
"${PARSER_DIR}/parser/common/op_parser_factory.cc"
"${PARSER_DIR}/parser/common/parser_api.cc"
"${PARSER_DIR}/parser/common/parser_factory.cc"
"${PARSER_DIR}/parser/common/parser_fp16_t.cc"
"${PARSER_DIR}/parser/common/parser_inner_ctx.cc"
"${PARSER_DIR}/parser/common/parser_types.cc"
"${PARSER_DIR}/parser/common/parser_utils.cc"
"${PARSER_DIR}/parser/common/pass_manager.cc"
"${PARSER_DIR}/parser/common/pre_checker.cc"
"${PARSER_DIR}/parser/common/proto_file_parser.cc"
"${PARSER_DIR}/parser/common/prototype_pass_manager.cc"
"${PARSER_DIR}/parser/common/register_tbe.cc"
"${PARSER_DIR}/parser/common/tbe_plugin_loader.cc"
"${PARSER_DIR}/parser/common/thread_pool.cc"
"${PARSER_DIR}/parser/common/auto_mapping_subgraph_io_index_func.cc"
"${PARSER_DIR}/parser/onnx/onnx_constant_parser.cc"
"${PARSER_DIR}/parser/onnx/onnx_custom_parser_adapter.cc"
"${PARSER_DIR}/parser/onnx/onnx_data_parser.cc"
"${PARSER_DIR}/parser/onnx/onnx_parser.cc"
"${PARSER_DIR}/parser/onnx/onnx_util.cc"
"${PARSER_DIR}/parser/onnx/subgraph_adapter/if_subgraph_adapter.cc"
"${PARSER_DIR}/parser/onnx/subgraph_adapter/subgraph_adapter_factory.cc"
"${PARSER_DIR}/parser/tensorflow/graph_functiondef.cc"
"${PARSER_DIR}/parser/tensorflow/graph_optimizer.cc"
"${PARSER_DIR}/parser/tensorflow/iterator_fusion_pass.cc"
"${PARSER_DIR}/parser/tensorflow/scope/scope_pass_manager.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_arg_parser.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_auto_mapping_parser_adapter.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_constant_parser.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_custom_parser_adapter.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_data_parser.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_enter_parser.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_fill_parser.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_frameworkop_parser.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_fusionop_util.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_fusion_custom_parser_adapter.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_fusion_op_parser.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_identity_parser.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_merge_parser.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_no_op_parser.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_parser.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_ref_switch_parser.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_reshape_parser.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_shape_n_parser.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_squeeze_parser.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_util.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_variable_v2_parser.cc"
"${PARSER_DIR}/parser/tensorflow/tensorflow_var_is_initialized_op_parser.cc"
)

# include directories
include_directories(${CMAKE_CURRENT_LIST_DIR})
include_directories(${CMAKE_BINARY_DIR}/proto/ge)
include_directories(${PARSER_DIR})
include_directories(${PARSER_DIR}/inc)
include_directories(${PARSER_DIR}/parser)
include_directories(${PARSER_DIR}/parser/onnx)
include_directories(${PARSER_DIR}/tests)
include_directories(${PARSER_DIR}/metadef/inc)
include_directories(${PARSER_DIR}/metadef/inc/external)
include_directories(${PARSER_DIR}/metadef/inc/register)
include_directories(${PARSER_DIR}/metadef/third_party/fwkacllib/inc)
include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc)
include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/external)
include_directories(${PARSER_DIR}/metadef/third_party/graphengine/inc/framework)


set(PARSER_ST_FILES
"parser_st_utils.cc"
"testcase/test_main.cc"
"testcase/test_onnx_parser.cc"
"testcase/test_caffe_parser.cc"
"testcase/test_tensorflow_parser.cc"
)

############ libst_parser_common.a ############
add_library(st_parser_common STATIC
${PARSER_SRC_FILES} ${PARSER_PROTO_HDRS}
)

target_compile_definitions(st_parser_common PRIVATE
google=ascend_private
)

target_compile_options(st_parser_common PRIVATE
-g --coverage -fprofile-arcs -ftest-coverage
-Werror=format
)

target_link_libraries(st_parser_common PRIVATE
$<BUILD_INTERFACE:intf_pub>
st_parser_proto st_parser_graph c_sec
ascend_protobuf
json
)


################################################################################
add_executable(st_parser
${PARSER_ST_FILES} ${PARSER_PROTO_SRCS}
)

target_compile_options(st_parser PRIVATE
-g
)

target_compile_definitions(st_parser PRIVATE
google=ascend_private
)

target_link_libraries(st_parser
$<BUILD_INTERFACE:intf_pub>
st_parser_proto
-Wl,--whole-archive st_parser_common -Wl,--no-whole-archive
st_parser_graph st_parser_register error_manager_stub mmpa_stub attr_util_stub
gtest gtest_main slog_stub ascend_protobuf c_sec -lrt -ldl -lgcov
)

+ 44
- 0
tests/st/parser_st_utils.cc View File

@@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "st/parser_st_utils.h"
#include "framework/common/debug/ge_log.h"

namespace ge {
void ParerSTestsUtils::ClearParserInnerCtx() {
ge::GetParserContext().input_nodes_format_map.clear();
ge::GetParserContext().output_formats.clear();
ge::GetParserContext().user_input_dims.clear();
ge::GetParserContext().input_dims.clear();
ge::GetParserContext().op_conf_map.clear();
ge::GetParserContext().user_out_nodes.clear();
ge::GetParserContext().default_out_nodes.clear();
ge::GetParserContext().out_nodes_map.clear();
ge::GetParserContext().user_out_tensors.clear();
ge::GetParserContext().net_out_nodes.clear();
ge::GetParserContext().out_tensor_names.clear();
ge::GetParserContext().data_tensor_names.clear();
ge::GetParserContext().is_dynamic_input = false;
ge::GetParserContext().train_flag = false;
ge::GetParserContext().format = domi::DOMI_TENSOR_ND;
ge::GetParserContext().type = domi::FRAMEWORK_RESERVED;
ge::GetParserContext().run_mode = GEN_OM_MODEL;
ge::GetParserContext().custom_proto_path = "";
ge::GetParserContext().caffe_proto_path = "";
ge::GetParserContext().enable_scope_fusion_passes = "";
GELOGI("Clear parser inner context successfully.");
}
} // namespace ge

+ 29
- 0
tests/st/parser_st_utils.h View File

@@ -0,0 +1,29 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_PARSER_TESTS_UT_PARSER_H_
#define GE_PARSER_TESTS_UT_PARSER_H_

#include "framework/omg/parser/parser_inner_ctx.h"

namespace ge {
class ParerSTestsUtils {
public:
static void ClearParserInnerCtx();
};
} // namespace ge

#endif // GE_PARSER_TESTS_UT_PARSER_H_

+ 14
- 0
tests/st/testcase/origin_models/caffe_abs.pbtxt View File

@@ -0,0 +1,14 @@
name: "TestAbs"
layer {
name: "data"
type: "Input"
top: "data"
input_param { shape: { dim: 64 dim: 1 dim: 28 dim: 28 } }
}
layer {
name: "abs"
type: "AbsVal"
bottom: "data"
top: "abs_out"
}

BIN
tests/st/testcase/origin_models/onnx_conv2d.onnx View File


BIN
tests/st/testcase/origin_models/onnx_if.onnx View File


+ 13
- 0
tests/st/testcase/origin_models/tf_add.pb View File

@@ -0,0 +1,13 @@

8
Placeholder Placeholder*
dtype0*
shape:
:
Placeholder_1 Placeholder*
dtype0*
shape:
6

add_test_1Add Placeholder Placeholder_1*
T0"†

+ 102
- 0
tests/st/testcase/test_caffe_parser.cc View File

@@ -0,0 +1,102 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include "parser/common/op_parser_factory.h"
#include "graph/operator_reg.h"
#include "register/op_registry.h"
#include "parser/common/register_tbe.h"
#include "framework/omg/parser/model_parser.h"
#include "framework/omg/parser/parser_factory.h"
#include "external/parser/caffe_parser.h"
#include "st/parser_st_utils.h"
#include "external/ge/ge_api_types.h"

namespace ge {
class STestCaffeParser : public testing::Test {
protected:
void SetUp() {
ParerSTestsUtils::ClearParserInnerCtx();
RegisterCustomOp();
}

void TearDown() {}

public:
void RegisterCustomOp();
};

static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) {
return SUCCESS;
}
void STestCaffeParser::RegisterCustomOp() {
REGISTER_CUSTOM_OP("Data")
.FrameworkType(domi::CAFFE)
.OriginOpType("Input")
.ParseParamsFn(ParseParams);

REGISTER_CUSTOM_OP("Abs")
.FrameworkType(domi::CAFFE)
.OriginOpType("AbsVal")
.ParseParamsFn(ParseParams);

std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas;
for (auto reg_data : reg_datas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegistry::Instance()->Register(reg_data);
}
domi::OpRegistry::Instance()->registrationDatas.clear();
}

namespace {
REG_OP(Data)
.INPUT(x, TensorType::ALL())
.OUTPUT(y, TensorType::ALL())
.ATTR(index, Int, 0)
.OP_END_FACTORY_REG(Data)

REG_OP(Abs)
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64}))
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64}))
.OP_END_FACTORY_REG(Abs)
}

TEST_F(STestCaffeParser, caffe_parser_user_output_with_default) {
std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/origin_models/caffe_abs.pbtxt";
auto model_parser = domi::ModelParserFactory::Instance()->CreateModelParser(domi::CAFFE);
ASSERT_NE(model_parser, nullptr);
ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmp_graph");
ASSERT_NE(compute_graph, nullptr);
ge::Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph);
auto ret = model_parser->Parse(model_file.c_str(), graph);
ASSERT_EQ(ret, GRAPH_SUCCESS);
AclGrphParseUtil acl_graph_parse_util;
std::map<AscendString, AscendString> parser_params;
auto status = acl_graph_parse_util.SetOutputNodeInfo(graph, parser_params);
ASSERT_EQ(status, SUCCESS);

auto output_nodes_info = compute_graph->GetGraphOutNodesInfo();
ASSERT_EQ(output_nodes_info.size(), 1);
EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "abs");
EXPECT_EQ((output_nodes_info.at(0).second), 0);
auto &net_out_name = ge::GetParserContext().net_out_nodes;
ASSERT_EQ(net_out_name.size(), 1);
EXPECT_EQ(net_out_name.at(0), "abs:0:abs_out");
}

} // namespace ge

+ 27
- 0
tests/st/testcase/test_main.cc View File

@@ -0,0 +1,27 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <iostream>
#include <gtest/gtest.h>

using namespace std;

int main(int argc, char **argv) {
testing::InitGoogleTest(&argc, argv);
int ret = RUN_ALL_TESTS();
std::cout << "Finish parser st." << std::endl;
return ret;
}

+ 185
- 0
tests/st/testcase/test_onnx_parser.cc View File

@@ -0,0 +1,185 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <iostream>
#include "parser/common/op_parser_factory.h"
#include "graph/operator_reg.h"
#include "register/op_registry.h"
#include "parser/common/register_tbe.h"
#include "external/parser/onnx_parser.h"
#include "st/parser_st_utils.h"
#include "external/ge/ge_api_types.h"

namespace ge {
class STestOnnxParser : public testing::Test {
protected:
void SetUp() {
ParerSTestsUtils::ClearParserInnerCtx();
RegisterCustomOp();
}

void TearDown() {}

public:
void RegisterCustomOp();
};

static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) {
return SUCCESS;
}

static Status ParseParamByOpFunc(const ge::Operator &op_src, ge::Operator& op_dest) {
return SUCCESS;
}

Status ParseSubgraphPostFnIf(const std::string& subgraph_name, const ge::Graph& graph) {
domi::AutoMappingSubgraphIOIndexFunc auto_mapping_subgraph_index_func =
domi::FrameworkRegistry::Instance().GetAutoMappingSubgraphIOIndexFunc(domi::ONNX);
if (auto_mapping_subgraph_index_func == nullptr) {
std::cout<<"auto mapping if subgraph func is nullptr!"<<std::endl;
return FAILED;
}
return auto_mapping_subgraph_index_func(graph,
[&](int data_index, int &parent_index) -> Status {
parent_index = data_index + 1;
return SUCCESS;
},
[&](int output_index, int &parent_index) -> Status {
parent_index = output_index;
return SUCCESS;
});
}

void STestOnnxParser::RegisterCustomOp() {
REGISTER_CUSTOM_OP("Conv2D")
.FrameworkType(domi::ONNX)
.OriginOpType("ai.onnx::11::Conv")
.ParseParamsFn(ParseParams);

// register if op info to GE
REGISTER_CUSTOM_OP("If")
.FrameworkType(domi::ONNX)
.OriginOpType({"ai.onnx::9::If",
"ai.onnx::10::If",
"ai.onnx::11::If",
"ai.onnx::12::If",
"ai.onnx::13::If"})
.ParseParamsFn(ParseParams)
.ParseParamsByOperatorFn(ParseParamByOpFunc)
.ParseSubgraphPostFn(ParseSubgraphPostFnIf);

REGISTER_CUSTOM_OP("Add")
.FrameworkType(domi::ONNX)
.OriginOpType("ai.onnx::11::Add")
.ParseParamsFn(ParseParams);

REGISTER_CUSTOM_OP("Identity")
.FrameworkType(domi::ONNX)
.OriginOpType("ai.onnx::11::Identity")
.ParseParamsFn(ParseParams);

std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas;
for (auto reg_data : reg_datas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegistry::Instance()->Register(reg_data);
}
domi::OpRegistry::Instance()->registrationDatas.clear();
}

namespace {
REG_OP(Data)
.INPUT(x, TensorType::ALL())
.OUTPUT(y, TensorType::ALL())
.ATTR(index, Int, 0)
.OP_END_FACTORY_REG(Data)

REG_OP(Const)
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, \
DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.ATTR(value, Tensor, Tensor())
.OP_END_FACTORY_REG(Const)

REG_OP(Conv2D)
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8}))
.INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8}))
.OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
.OPTIONAL_INPUT(offset_w, TensorType({DT_INT8}))
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32}))
.REQUIRED_ATTR(strides, ListInt)
.REQUIRED_ATTR(pads, ListInt)
.ATTR(dilations, ListInt, {1, 1, 1, 1})
.ATTR(groups, Int, 1)
.ATTR(data_format, String, "NHWC")
.ATTR(offset_x, Int, 0)
.OP_END_FACTORY_REG(Conv2D)

REG_OP(If)
.INPUT(cond, TensorType::ALL())
.DYNAMIC_INPUT(input, TensorType::ALL())
.DYNAMIC_OUTPUT(output, TensorType::ALL())
.GRAPH(then_branch)
.GRAPH(else_branch)
.OP_END_FACTORY_REG(If)

REG_OP(Add)
.INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16,
DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
DT_COMPLEX64, DT_STRING}))
.INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16,
DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
DT_COMPLEX64, DT_STRING}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16,
DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
DT_COMPLEX64, DT_STRING}))
.OP_END_FACTORY_REG(Add)

REG_OP(Identity)
.INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8,
DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE}))
.OP_END_FACTORY_REG(Identity)
}

TEST_F(STestOnnxParser, onnx_parser_user_output_with_default) {
std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/origin_models/onnx_conv2d.onnx";
std::map<ge::AscendString, ge::AscendString> parser_params;
ge::Graph graph;
auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph);
ASSERT_EQ(ret, GRAPH_SUCCESS);
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
auto output_nodes_info = compute_graph->GetGraphOutNodesInfo();
ASSERT_EQ(output_nodes_info.size(), 1);
EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "Conv_0");
EXPECT_EQ((output_nodes_info.at(0).second), 0);
auto &net_out_name = ge::GetParserContext().net_out_nodes;
ASSERT_EQ(net_out_name.size(), 1);
EXPECT_EQ(net_out_name.at(0), "Conv_0:0:y");
}

TEST_F(STestOnnxParser, onnx_parser_if_node) {
std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/origin_models/onnx_if.onnx";
std::map<ge::AscendString, ge::AscendString> parser_params;
ge::Graph graph;
auto ret = ge::aclgrphParseONNX(model_file.c_str(), parser_params, graph);
EXPECT_EQ(ret, GRAPH_SUCCESS);
}
} // namespace ge

+ 96
- 0
tests/st/testcase/test_tensorflow_parser.cc View File

@@ -0,0 +1,96 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include "parser/common/op_parser_factory.h"
#include "parser/tensorflow/tensorflow_parser.h"
#include "graph/operator_reg.h"
#include "register/op_registry.h"
#include "parser/common/register_tbe.h"
#include "external/parser/tensorflow_parser.h"
#include "st/parser_st_utils.h"

namespace ge {
class STestTensorflowParser : public testing::Test {
protected:
void SetUp() {
ParerSTestsUtils::ClearParserInnerCtx();
}

void TearDown() {}

public:
void RegisterCustomOp();
};

static Status ParseParams(const google::protobuf::Message* op_src, ge::Operator& op_dest) {
return SUCCESS;
}

void STestTensorflowParser::RegisterCustomOp() {
REGISTER_CUSTOM_OP("Add")
.FrameworkType(domi::TENSORFLOW)
.OriginOpType("Add")
.ParseParamsFn(ParseParams);

std::vector<OpRegistrationData> reg_datas = domi::OpRegistry::Instance()->registrationDatas;
for (auto reg_data : reg_datas) {
OpRegistrationTbe::Instance()->Finalize(reg_data);
domi::OpRegistry::Instance()->Register(reg_data);
}
domi::OpRegistry::Instance()->registrationDatas.clear();
}

namespace {
REG_OP(Data)
.INPUT(x, TensorType::ALL())
.OUTPUT(y, TensorType::ALL())
.ATTR(index, Int, 0)
.OP_END_FACTORY_REG(Data)

REG_OP(Add)
.INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16,
DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
DT_COMPLEX64, DT_STRING}))
.INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16,
DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
DT_COMPLEX64, DT_STRING}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16,
DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128,
DT_COMPLEX64, DT_STRING}))
.OP_END_FACTORY_REG(Add)
}

TEST_F(STestTensorflowParser, tensorflow_parser_success) {
RegisterCustomOp();

std::string case_dir = __FILE__;
case_dir = case_dir.substr(0, case_dir.find_last_of("/"));
std::string model_file = case_dir + "/origin_models/tf_add.pb";
std::map<ge::AscendString, ge::AscendString> parser_params;
ge::Graph graph;
auto ret = ge::aclgrphParseTensorFlow(model_file.c_str(), parser_params, graph);
ASSERT_EQ(ret, SUCCESS);
ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph);
auto output_nodes_info = compute_graph->GetGraphOutNodesInfo();
ASSERT_EQ(output_nodes_info.size(), 1);
EXPECT_EQ((output_nodes_info.at(0).first->GetName()), "add_test_1");
EXPECT_EQ((output_nodes_info.at(0).second), 0);
auto &net_out_name = ge::GetParserContext().net_out_nodes;
ASSERT_EQ(net_out_name.size(), 1);
EXPECT_EQ(net_out_name.at(0), "add_test_1:0");
}
} // namespace ge

Loading…
Cancel
Save