| Author | SHA1 | Message | Date |
|---|---|---|---|
|
|
a879159cd9 |
!181 parser one to many add original name
Merge pull request !181 from yangyongqiang/master |
5 years ago |
|
|
f5b9b56976 | parser one to many add original name | 5 years ago |
|
|
2458a4fb10 |
!168 mod add control edge log
Merge pull request !168 from 王笑天/master |
5 years ago |
|
|
921c0f4a4a | mod add control edge log print | 5 years ago |
|
|
6ef2a32651 |
!159 remove specific process when swtich control edge to other node
Merge pull request !159 from 王笑天/master |
5 years ago |
|
|
4bf949e8f4 | remove specific process when swtich control edge to other node | 5 years ago |
|
|
c1d95bac3a |
!145 restore code
Merge pull request !145 from yangyongqiang/master |
5 years ago |
|
|
118afdc74f | restore code | 5 years ago |
|
|
befc2aac08 |
!141 add validation of fmk type for plugin load.
Merge pull request !141 from yangyongqiang/master |
5 years ago |
|
|
ed68fe4968 | add validation of fmk type for plugin load. | 5 years ago |
|
|
b5d7c0d3da |
!140 add validation of fmk type for plugin load.
Merge pull request !140 from yangyongqiang/master |
5 years ago |
|
|
1196f14a49 | add validation of fmk type for plugin load. | 5 years ago |
|
|
c841458262 |
!136 ONLY_COMPILE_OPEN_SRC_METADEF
Merge pull request !136 from yangyongqiang/master |
5 years ago |
|
|
e3607a9bab | ONLY_COMPILE_OPEN_SRC_METADEF | 5 years ago |
|
|
13f2b268d9 |
!132 remove parser ONLY_COMPILE_OPEN_SRC
Merge pull request !132 from yangyongqiang/master |
5 years ago |
|
|
62ba68da76 | remove parser ONLY_COMPILE_OPEN_SRC | 5 years ago |
|
|
3380d801ae |
!127 parser one to many
Merge pull request !127 from yangyongqiang/master |
5 years ago |
|
|
1a2a156d7d | parser one to many | 5 years ago |
|
|
5dd26bcc0c |
!124 parser one to many
Merge pull request !124 from yangyongqiang/master |
5 years ago |
|
|
ecfa6f1a12 | parser one to many | 5 years ago |
|
|
58d657a959 |
!122 update submodule metadef
Merge pull request !122 from yangyongqiang/master |
5 years ago |
|
|
11be91b6f5 | update submodule metadef | 5 years ago |
|
|
76a862b1bc | update README.md. | 5 years ago |
|
|
4dc4e4c051 |
!68 solve product side
Merge pull request !68 from taoxudong/master |
5 years ago |
|
|
7c25d2c3be | solve product side | 5 years ago |
|
|
45cee6f977 |
!65 add pb2json
Merge pull request !65 from taoxiangdong/master |
5 years ago |
|
|
deea2f1d34 | add convert pb2josn | 5 years ago |
|
|
87defa07f8 |
!64 update roadmap
Merge pull request !64 from 王涛/master |
5 years ago |
|
|
7cd0089118 | update README.md. | 5 years ago |
|
|
99e435028c |
!52 parser depend on ge in master branch
Merge pull request !52 from taoxiangdong/master |
5 years ago |
|
|
496594e1a3 | update parser denpend on ge | 5 years ago |
|
|
b2ce41d5bb |
!50 update cmake compile dependency
Merge pull request !50 from taoxiangdong/master |
5 years ago |
|
|
4555f2ded5 | update submodule metadef | 5 years ago |
|
|
18d6a45abe | update cmake compile dependency | 5 years ago |
|
|
5c097bb666 |
!48 delete parser common convert
Merge pull request !48 from taoxiangdong/master |
5 years ago |
|
|
112797a552 | update parser common convert | 5 years ago |
|
|
c6b1f992db |
!46 update submoduel metadef
Merge pull request !46 from taoxiangdong/master |
5 years ago |
|
|
ca6227490e | update submodule metadef | 5 years ago |
|
|
f60e66c153 |
!45 update metadef submoudle
Merge pull request !45 from taoxiangdong/master |
5 years ago |
|
|
048cd7588f | update submodule metadef | 5 years ago |
|
|
b97123e9d5 | update master parser src code part2 | 5 years ago |
|
|
6c37290fef | update README.md. | 5 years ago |
|
|
cc57aded17 | update README.md. | 5 years ago |
|
|
27e4a3c31e | update README.md. | 5 years ago |
|
|
f3d7bd35d4 | update README.md. | 5 years ago |
|
|
cec81adf13 | update submodule metadef | 5 years ago |
|
|
5fa9c1276e |
!43 update src code frome yellow zone
Merge pull request !43 from taoxiangdong/master |
5 years ago |
|
|
dcfe59ef7a | update master src code from yellow zone | 5 years ago |
|
|
fee27781d6 | Merge remote-tracking branch 'upstream/development' | 5 years ago |
|
|
5ea9437bbe |
!31 update json cmake
Merge pull request !31 from taoxiangdong/master |
5 years ago |
|
|
be653a1703 |
!29 update cmake
Merge pull request !29 from taoxiangdong/master |
5 years ago |
|
|
c53d788e91 |
!27 update submodule
Merge pull request !27 from taoxiangdong/master |
5 years ago |
|
|
763e9d285d |
!25 updtae submodule metadef
Merge pull request !25 from taoxiangdong/master |
5 years ago |
|
|
191091a3dd |
!23 sync from yellow zone 20201020
Merge pull request !23 from taoxiangdong/master |
5 years ago |
|
|
88eb8af4ae |
!21 update git submodule metadef
Merge pull request !21 from taoxiangdong/master |
5 years ago |
|
|
27f3757531 |
!19 remove compile cache
Merge pull request !19 from taoxiangdong/master |
5 years ago |
|
|
b7702a4fa4 |
!17 add metadef submodule
Merge pull request !17 from taoxiangdong/master |
5 years ago |
|
|
85cd977a30 |
!15 update atc cmake
Merge pull request !15 from taoxiangdong/master |
5 years ago |
|
|
240e4efb44 |
!13 add build.sh
Merge pull request !13 from taoxiangdong/master |
5 years ago |
|
|
996bf32d3e |
!11 update cmake
Merge pull request !11 from taoxiangdong/master |
5 years ago |
|
|
9a7eb21ea5 |
!9 update cmakelists
Merge pull request !9 from taoxiangdong/master |
5 years ago |
|
|
7e69a9b628 |
!7 update readme for parser
Merge pull request !7 from 王正俊/master |
5 years ago |
| @@ -1,4 +1,4 @@ | |||||
| [submodule "metadef"] | [submodule "metadef"] | ||||
| path = metadef | path = metadef | ||||
| url = https://gitee.com/ascend/metadef.git | url = https://gitee.com/ascend/metadef.git | ||||
| branch = development | |||||
| branch = master | |||||
| @@ -9,7 +9,7 @@ parser以动态库的方式被调用。 | |||||
| ### 源码安装 | ### 源码安装 | ||||
| 进行源码编译前,确保系统满足以下要求: | |||||
| Parser支持由源码编译,进行源码编译前,首先确保你有昇腾910 AI处理器的环境进行源码编译前,确保系统满足以下要求: | |||||
| - GCC >= 7.3.0 | - GCC >= 7.3.0 | ||||
| - CMake >= 3.14.0 | - CMake >= 3.14.0 | ||||
| @@ -22,19 +22,41 @@ parser以动态库的方式被调用。 | |||||
| ``` | ``` | ||||
| git clone https://gitee.com/ascend/parser.git | git clone https://gitee.com/ascend/parser.git | ||||
| cd parser | cd parser | ||||
| git submodule init && git submodule update | |||||
| ``` | ``` | ||||
| #### 源码编译 | #### 源码编译 | ||||
| 在parser根目录执行以下命令编译: | |||||
| ``` | ``` | ||||
| // 正在补充 | |||||
| 目前parser需要集成到mindspore/graphengine中使用,暂不支持独立编译,解耦独立编译正在开发中,敬请期待; | |||||
| ``` | ``` | ||||
| ## 贡献 | ## 贡献 | ||||
| 欢迎参与贡献。 | 欢迎参与贡献。 | ||||
| ## 路标 | |||||
| 以下将展示graphenine/parser近期的计划,我们会根据用户的反馈诉求,持续调整计划的优先级。 | |||||
| 总体而言,我们会努力在以下几个方面不断改进。 | |||||
| 1、完备性:Cast/ConcatV2算子支持输入数据类型为int64的常量折叠; | |||||
| 2、完备性:onnx parser支持一对多映射; | |||||
| 3、架构优化:ATC解耦并迁移至parser; | |||||
| 4、易用性:提供tensorflow训练的checkpoint文件转pb文件的一键式转化工具; | |||||
| 5、易用性:提供一键式本地编译环境构建工具; | |||||
| 6、可维测:ATC转换生成的om模型包含框架信息、cann版本信息和芯片信息等; | |||||
| 热忱希望各位在用户社区加入讨论,并贡献您的建议。 | |||||
| ## Release Notes | ## Release Notes | ||||
| Release Notes请参考[RELEASE](RELEASE.md)。 | Release Notes请参考[RELEASE](RELEASE.md)。 | ||||
| @@ -23,7 +23,7 @@ export BUILD_PATH="${BASEPATH}/build/" | |||||
| usage() | usage() | ||||
| { | { | ||||
| echo "Usage:" | echo "Usage:" | ||||
| echo "sh build.sh [-j[n]] [-h] [-v] [-s] [-t] [-u] [-c]" | |||||
| echo "sh build.sh [-j[n]] [-h] [-v] [-s] [-t] [-u] [-c] [-S on|off]" | |||||
| echo "" | echo "" | ||||
| echo "Options:" | echo "Options:" | ||||
| echo " -h Print usage" | echo " -h Print usage" | ||||
| @@ -33,9 +33,21 @@ usage() | |||||
| echo " -t Build and execute ut" | echo " -t Build and execute ut" | ||||
| echo " -c Build ut with coverage tag" | echo " -c Build ut with coverage tag" | ||||
| echo " -v Display build command" | echo " -v Display build command" | ||||
| echo " -S Enable enable download cmake compile dependency from gitee , default off" | |||||
| echo "to be continued ..." | echo "to be continued ..." | ||||
| } | } | ||||
| # check value of input is 'on' or 'off' | |||||
| # usage: check_on_off arg_value arg_name | |||||
| check_on_off() | |||||
| { | |||||
| if [[ "X$1" != "Xon" && "X$1" != "Xoff" ]]; then | |||||
| echo "Invalid value $1 for option -$2" | |||||
| usage | |||||
| exit 1 | |||||
| fi | |||||
| } | |||||
| # parse and set options | # parse and set options | ||||
| checkopts() | checkopts() | ||||
| { | { | ||||
| @@ -46,8 +58,9 @@ checkopts() | |||||
| ENABLE_GE_ST="off" | ENABLE_GE_ST="off" | ||||
| ENABLE_GE_COV="off" | ENABLE_GE_COV="off" | ||||
| GE_ONLY="on" | GE_ONLY="on" | ||||
| ENABLE_GITEE="off" | |||||
| # Process the options | # Process the options | ||||
| while getopts 'ustchj:v' opt | |||||
| while getopts 'ustchj:vS:' opt | |||||
| do | do | ||||
| OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') | OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') | ||||
| case "${opt}" in | case "${opt}" in | ||||
| @@ -77,6 +90,11 @@ checkopts() | |||||
| v) | v) | ||||
| VERBOSE="VERBOSE=1" | VERBOSE="VERBOSE=1" | ||||
| ;; | ;; | ||||
| S) | |||||
| check_on_off $OPTARG S | |||||
| ENABLE_GITEE="$OPTARG" | |||||
| echo "enable download from gitee" | |||||
| ;; | |||||
| *) | *) | ||||
| echo "Undefined option: ${opt}" | echo "Undefined option: ${opt}" | ||||
| usage | usage | ||||
| @@ -119,6 +137,10 @@ build_parser() | |||||
| CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GE_ST=ON" | CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GE_ST=ON" | ||||
| fi | fi | ||||
| if [[ "X$ENABLE_GITEE" = "Xon" ]]; then | |||||
| CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_GITEE=ON" | |||||
| fi | |||||
| CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_OPEN_SRC=True -DCMAKE_INSTALL_PREFIX=${OUTPUT_PATH}" | CMAKE_ARGS="${CMAKE_ARGS} -DENABLE_OPEN_SRC=True -DCMAKE_INSTALL_PREFIX=${OUTPUT_PATH}" | ||||
| echo "${CMAKE_ARGS}" | echo "${CMAKE_ARGS}" | ||||
| cmake ${CMAKE_ARGS} .. | cmake ${CMAKE_ARGS} .. | ||||
| @@ -186,7 +208,6 @@ generate_package() | |||||
| done | done | ||||
| find ${OUTPUT_PATH}/${PARSER_LIB_PATH} -maxdepth 1 -name "libc_sec.so" -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH} \; | find ${OUTPUT_PATH}/${PARSER_LIB_PATH} -maxdepth 1 -name "libc_sec.so" -exec cp -f {} ${OUTPUT_PATH}/${ATC_PATH} \; | ||||
| find ${OUTPUT_PATH}/${PARSER_LIB_PATH} -maxdepth 1 -name "libregister.a" -exec cp -f {} ${OUTPUT_PATH}/${ACL_PATH} \; | |||||
| tar -cf parser_lib.tar fwkacllib acllib atc | tar -cf parser_lib.tar fwkacllib acllib atc | ||||
| } | } | ||||
| @@ -11,8 +11,16 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||||
| message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | ||||
| endif() | endif() | ||||
| if (ENABLE_GITEE) | |||||
| set(REQ_URL "https://gitee.com/mirrors/gflags/repository/archive/v2.2.2.tar.gz") | |||||
| set(MD5 "") | |||||
| else() | |||||
| set(REQ_URL "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz") | |||||
| set(MD5 "") | |||||
| endif () | |||||
| ExternalProject_Add(gflags_build | ExternalProject_Add(gflags_build | ||||
| URL https://github.com/gflags/gflags/archive/v2.2.2.tar.gz | |||||
| URL ${REQ_URL} | |||||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | ||||
| #SOURCE_DIR ${PARSER_DIR}/../../third_party/gflags/src/gflags-2.2.2 | #SOURCE_DIR ${PARSER_DIR}/../../third_party/gflags/src/gflags-2.2.2 | ||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=0" -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gflags <SOURCE_DIR> | CONFIGURE_COMMAND ${CMAKE_COMMAND} -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=0" -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/gflags <SOURCE_DIR> | ||||
| @@ -5,8 +5,17 @@ endif() | |||||
| include(ExternalProject) | include(ExternalProject) | ||||
| set(JSON_SRC_DIR ${CMAKE_BINARY_DIR}/opensrc/json/include) | set(JSON_SRC_DIR ${CMAKE_BINARY_DIR}/opensrc/json/include) | ||||
| if (ENABLE_GITEE) | |||||
| set(REQ_URL "https://gitee.com/mirrors/JSON-for-Modern-CPP/repository/archive/v3.6.1.zip") | |||||
| set(MD5 "5bda78ce308e6cfcf614dcf1d5ff27a7") | |||||
| set(JSON_INCLUDE_DIR "${JSON_SRC_DIR}/include") | |||||
| else() | |||||
| set(REQ_URL "https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip") | |||||
| set(MD5 "0dc903888211db3a0f170304cd9f3a89") | |||||
| set(JSON_INCLUDE_DIR ${JSON_SRC_DIR}) | |||||
| endif () | |||||
| ExternalProject_Add(json_build | ExternalProject_Add(json_build | ||||
| URL https://github.com/nlohmann/json/releases/download/v3.6.1/include.zip | |||||
| URL ${REQ_URL} | |||||
| #URL /home/txd/workspace/cloud_code/pkg/include.zip | #URL /home/txd/workspace/cloud_code/pkg/include.zip | ||||
| SOURCE_DIR ${JSON_SRC_DIR} | SOURCE_DIR ${JSON_SRC_DIR} | ||||
| CONFIGURE_COMMAND "" | CONFIGURE_COMMAND "" | ||||
| @@ -17,7 +26,7 @@ ExternalProject_Add(json_build | |||||
| add_library(json INTERFACE) | add_library(json INTERFACE) | ||||
| target_include_directories(json INTERFACE ${JSON_SRC_DIR}) | |||||
| target_include_directories(json INTERFACE ${JSON_INCLUDE_DIR}) | |||||
| add_dependencies(json json_build) | add_dependencies(json json_build) | ||||
| #set(HAVE_JSON TRUE CACHE BOOL "json build add") | #set(HAVE_JSON TRUE CACHE BOOL "json build add") | ||||
| @@ -6,8 +6,16 @@ set(ONNX_PROTO_DIR ${CMAKE_BINARY_DIR}/onnx) | |||||
| set(ONNX_PROTO_FILE ${ONNX_PROTO_DIR}/onnx.proto) | set(ONNX_PROTO_FILE ${ONNX_PROTO_DIR}/onnx.proto) | ||||
| file(MAKE_DIRECTORY ${ONNX_PROTO_DIR}) | file(MAKE_DIRECTORY ${ONNX_PROTO_DIR}) | ||||
| if (ENABLE_GITEE) | |||||
| set(REQ_URL "https://gitee.com/mirrors/ONNX/repository/archive/v1.6.0.tar.gz") | |||||
| set(MD5 "1bdbcecdd68ea8392630467646776e02") | |||||
| else() | |||||
| set(REQ_URL "https://github.com/onnx/onnx/releases/download/v1.6.0/onnx-1.6.0.tar.gz") | |||||
| set(MD5 "512f2779d6215d4a36f366b6b9acdf1e") | |||||
| endif () | |||||
| ExternalProject_Add(onnx | ExternalProject_Add(onnx | ||||
| URL https://github.com/onnx/onnx/releases/download/v1.6.0/onnx-1.6.0.tar.gz | |||||
| URL ${REQ_URL} | |||||
| #URL /home/txd/workspace/cloud_code/pkg/onnx-1.6.0.tar.gz | #URL /home/txd/workspace/cloud_code/pkg/onnx-1.6.0.tar.gz | ||||
| #URL_HASH SHA256=3b88c3fe521151651a0403c4d131cb2e0311bd28b753ef692020a432a81ce345 | #URL_HASH SHA256=3b88c3fe521151651a0403c4d131cb2e0311bd28b753ef692020a432a81ce345 | ||||
| #SOURCE_DIR ${ONNX_SRC_DIR} | #SOURCE_DIR ${ONNX_SRC_DIR} | ||||
| @@ -11,10 +11,18 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||||
| message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | ||||
| endif() | endif() | ||||
| if (ENABLE_GITEE) | |||||
| set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236") | |||||
| else() | |||||
| set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "3d9e32700639618a4d2d342c99d4507a") | |||||
| endif () | |||||
| set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2") | set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2") | ||||
| set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | ||||
| ExternalProject_Add(protobuf_build | ExternalProject_Add(protobuf_build | ||||
| URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz | |||||
| URL ${REQ_URL} | |||||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | ||||
| #SOURCE_DIR ${PARSER_DIR}/../third_party/protobuf/src/protobuf-3.8.0 | #SOURCE_DIR ${PARSER_DIR}/../third_party/protobuf/src/protobuf-3.8.0 | ||||
| #DOWNLOAD_COMMAND ${CMAKE_COMMAND} -E copy_directory ${PARSER_DIR}/../third_party/protobuf/src/protobuf-3.8.0 <SOURCE_DIR> | #DOWNLOAD_COMMAND ${CMAKE_COMMAND} -E copy_directory ${PARSER_DIR}/../third_party/protobuf/src/protobuf-3.8.0 <SOURCE_DIR> | ||||
| @@ -8,11 +8,19 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||||
| message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | ||||
| endif() | endif() | ||||
| if (ENABLE_GITEE) | |||||
| set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236") | |||||
| else() | |||||
| set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "3d9e32700639618a4d2d342c99d4507a") | |||||
| endif () | |||||
| set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2") | set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2") | ||||
| set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | ||||
| set(PROTOBUF_STATIC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf_static) | set(PROTOBUF_STATIC_PKG_DIR ${CMAKE_INSTALL_PREFIX}/protobuf_static) | ||||
| ExternalProject_Add(protobuf_static_build | ExternalProject_Add(protobuf_static_build | ||||
| URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz | |||||
| URL ${REQ_URL} | |||||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | ||||
| #SOURCE_DIR ${PARSER_DIR}/../../third_party/protobuf/src/protobuf-3.8.0 | #SOURCE_DIR ${PARSER_DIR}/../../third_party/protobuf/src/protobuf-3.8.0 | ||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} | CONFIGURE_COMMAND ${CMAKE_COMMAND} | ||||
| @@ -12,10 +12,19 @@ if ((${CMAKE_INSTALL_PREFIX} STREQUAL /usr/local) OR | |||||
| message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | message(STATUS "No install prefix selected, default to ${CMAKE_INSTALL_PREFIX}.") | ||||
| endif() | endif() | ||||
| if (ENABLE_GITEE) | |||||
| set(REQ_URL "https://gitee.com/mirrors/protobuf_source/repository/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "eba86ae9f07ba5cfbaf8af3bc4e84236") | |||||
| else() | |||||
| set(REQ_URL "https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz") | |||||
| set(MD5 "3d9e32700639618a4d2d342c99d4507a") | |||||
| endif () | |||||
| set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2") | set(protobuf_CXXFLAGS "-Wno-maybe-uninitialized -Wno-unused-parameter -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -D_GLIBCXX_USE_CXX11_ABI=0 -O2") | ||||
| set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | set(protobuf_LDFLAGS "-Wl,-z,relro,-z,now,-z,noexecstack") | ||||
| ExternalProject_Add(protoc_build | ExternalProject_Add(protoc_build | ||||
| URL https://github.com/protocolbuffers/protobuf/archive/v3.8.0.tar.gz | |||||
| URL ${REQ_URL} | |||||
| #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | #URL /home/txd/workspace/linux_cmake/pkg/protobuf-3.8.0.tar.gz | ||||
| #SOURCE_DIR ${PARSER_DIR}/../third_party/protobuf/src/protobuf-3.8.0 | #SOURCE_DIR ${PARSER_DIR}/../third_party/protobuf/src/protobuf-3.8.0 | ||||
| CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc <SOURCE_DIR>/cmake | CONFIGURE_COMMAND ${CMAKE_COMMAND} -Dprotobuf_WITH_ZLIB=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_CXX_FLAGS=${protobuf_CXXFLAGS} -DCMAKE_CXX_LDFLAGS=${protobuf_LDFLAGS} -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX}/protoc <SOURCE_DIR>/cmake | ||||
| @@ -1 +1 @@ | |||||
| Subproject commit cc9de48a7779cf95cab90a23db608421a691fd12 | |||||
| Subproject commit cba1ba3dbc8d933cd08ccd0dff586112c9501b75 | |||||
| @@ -131,7 +131,7 @@ target_compile_options(fmk_parser_stub PRIVATE | |||||
| ) | ) | ||||
| target_compile_definitions(fmk_parser_stub PRIVATE | target_compile_definitions(fmk_parser_stub PRIVATE | ||||
| $<$<STREQUAL:${PRODUCT_SIDE},host>:FMK_SUPPORT_DUMP> | |||||
| $<$<OR:$<STREQUAL:${PRODUCT_SIDE},host>,$<STREQUAL:${ENABLE_OPEN_SRC},True>>:FMK_SUPPORT_DUMP> | |||||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | PROTOBUF_INLINE_NOT_IN_HEADERS=0 | ||||
| REUSE_MEMORY=1 | REUSE_MEMORY=1 | ||||
| FMK_HOST_INFER | FMK_HOST_INFER | ||||
| @@ -18,17 +18,16 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include "common/debug/log.h" | #include "common/debug/log.h" | ||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/omg/omg_inner_types.h" | #include "framework/omg/omg_inner_types.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "register/op_registry.h" | #include "register/op_registry.h" | ||||
| using domi::ParseParamByOpFunc; | |||||
| using domi::ParseParamFunc; | using domi::ParseParamFunc; | ||||
| using domi::ParseParamByOpFunc; | |||||
| using std::vector; | using std::vector; | ||||
| namespace ge { | namespace ge { | ||||
| @@ -55,8 +54,8 @@ Status CaffeCustomParserAdapter::ParseParams(const Message *op_src, ge::OpDescPt | |||||
| } | } | ||||
| Status CaffeCustomParserAdapter::ParseParams(const Operator &op_src, ge::OpDescPtr &op_dest) { | Status CaffeCustomParserAdapter::ParseParams(const Operator &op_src, ge::OpDescPtr &op_dest) { | ||||
| GELOGI("Caffe custom op begin to params: layer name = %s, layer type= %s ", op_src.GetName().c_str(), | |||||
| op_src.GetOpType().c_str()); | |||||
| GELOGI("Caffe custom op begin to params: layer name = %s, layer type= %s ", | |||||
| op_src.GetName().c_str(), op_src.GetOpType().c_str()); | |||||
| GE_CHECK_NOTNULL(op_dest); | GE_CHECK_NOTNULL(op_dest); | ||||
| ParseParamByOpFunc custom_op_parser = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_src.GetOpType()); | ParseParamByOpFunc custom_op_parser = domi::OpRegistry::Instance()->GetParseParamByOperatorFunc(op_src.GetOpType()); | ||||
| @@ -86,7 +85,7 @@ Status CaffeCustomParserAdapter::ParseWeights(const Message *op_src, ge::NodePtr | |||||
| bool update_in_turn = (static_cast<int64_t >(op->GetAllInputsSize()) == (layer->bottom_size() + layer->blobs_size())); | bool update_in_turn = (static_cast<int64_t >(op->GetAllInputsSize()) == (layer->bottom_size() + layer->blobs_size())); | ||||
| int start_pos = layer->bottom_size(); | int start_pos = layer->bottom_size(); | ||||
| for (int i = 0; i < layer->blobs_size(); ++i) { | for (int i = 0; i < layer->blobs_size(); ++i) { | ||||
| ge::GeTensorPtr weight = ge::parser::MakeShared<ge::GeTensor>(); | |||||
| ge::GeTensorPtr weight = ge::MakeShared<ge::GeTensor>(); | |||||
| GE_CHECK_NOTNULL(weight); | GE_CHECK_NOTNULL(weight); | ||||
| GE_CHK_STATUS_RET(ConvertWeight(layer->blobs(i), layer->name(), weight), "Convert blobs(%d) for layer %s failed", i, | GE_CHK_STATUS_RET(ConvertWeight(layer->blobs(i), layer->name(), weight), "Convert blobs(%d) for layer %s failed", i, | ||||
| layer->name().c_str()); | layer->name().c_str()); | ||||
| @@ -98,14 +97,14 @@ Status CaffeCustomParserAdapter::ParseWeights(const Message *op_src, ge::NodePtr | |||||
| bias_en = fc_params_src.bias_term();); | bias_en = fc_params_src.bias_term();); | ||||
| auto bias_shape = weight->MutableTensorDesc().GetShape(); | auto bias_shape = weight->MutableTensorDesc().GetShape(); | ||||
| // The num 0, 1, 2, 3 represet the dim index. | // The num 0, 1, 2, 3 represet the dim index. | ||||
| bool matched = bias_en && bias_shape.GetDimNum() == static_cast<size_t>(ge::parser::DIM_DEFAULT_SIZE) && | |||||
| bool matched = bias_en && bias_shape.GetDimNum() == static_cast<size_t>(ge::DIM_DEFAULT_SIZE) && | |||||
| bias_shape.GetDim(0) == 1 && bias_shape.GetDim(1) == 1 && bias_shape.GetDim(2) == 1; | bias_shape.GetDim(0) == 1 && bias_shape.GetDim(1) == 1 && bias_shape.GetDim(2) == 1; | ||||
| if (matched) { | if (matched) { | ||||
| weight->MutableTensorDesc().SetShape(ge::GeShape({bias_shape.GetDim(3)})); | weight->MutableTensorDesc().SetShape(ge::GeShape({bias_shape.GetDim(3)})); | ||||
| } | } | ||||
| matched = layer->type() == kInnerProduct && i == 0 && | matched = layer->type() == kInnerProduct && i == 0 && | ||||
| bias_shape.GetDimNum() == static_cast<size_t>(ge::parser::DIM_DEFAULT_SIZE) && | |||||
| bias_shape.GetDim(0) == 1 && bias_shape.GetDim(1) == 1; | |||||
| bias_shape.GetDimNum() == static_cast<size_t>(ge::DIM_DEFAULT_SIZE) && bias_shape.GetDim(0) == 1 && | |||||
| bias_shape.GetDim(1) == 1; | |||||
| if (matched) { | if (matched) { | ||||
| weight->MutableTensorDesc().SetShape(ge::GeShape({bias_shape.GetDim(2), bias_shape.GetDim(3)})); | weight->MutableTensorDesc().SetShape(ge::GeShape({bias_shape.GetDim(2), bias_shape.GetDim(3)})); | ||||
| } | } | ||||
| @@ -18,15 +18,13 @@ | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <utility> | #include <utility> | ||||
| #include "common/debug/log.h" | #include "common/debug/log.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "common/types.h" | |||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "common/util/error_manager/error_manager.h" | #include "common/util/error_manager/error_manager.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/omg/parser/parser_inner_ctx.h" | |||||
| #include "omg/omg_inner_types.h" | |||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| using namespace ge::parser; | |||||
| namespace ge { | namespace ge { | ||||
| Status CaffeDataParser::GetOutputDesc(const string &name, int dim_size, const std::vector<int64_t> &input_dims, | Status CaffeDataParser::GetOutputDesc(const string &name, int dim_size, const std::vector<int64_t> &input_dims, | ||||
| ge::OpDescPtr &op) { | ge::OpDescPtr &op) { | ||||
| @@ -50,10 +48,10 @@ Status CaffeDataParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) { | |||||
| GE_CHECK_NOTNULL(layer); | GE_CHECK_NOTNULL(layer); | ||||
| GELOGD("Caffe layer name = %s, layer type= %s, parse params", layer->name().c_str(), layer->type().c_str()); | GELOGD("Caffe layer name = %s, layer type= %s, parse params", layer->name().c_str(), layer->type().c_str()); | ||||
| if (layer->type() == ge::parser::INPUT_TYPE) { | |||||
| if (layer->type() == ge::INPUT_TYPE) { | |||||
| GE_CHK_STATUS_RET(ParseParamsForInput(layer, op), "Caffe layer name = %s, layer type= %s, parse params failed", | GE_CHK_STATUS_RET(ParseParamsForInput(layer, op), "Caffe layer name = %s, layer type= %s, parse params failed", | ||||
| layer->name().c_str(), layer->type().c_str()); | layer->name().c_str(), layer->type().c_str()); | ||||
| } else if(layer->type() == ge::parser::DUMMY_DATA) { | |||||
| } else if(layer->type() == ge::DUMMY_DATA) { | |||||
| GE_CHK_STATUS_RET(ParseParamsForDummyData(layer, op), "Caffe layer name = %s, layer type= %s, parse params failed", | GE_CHK_STATUS_RET(ParseParamsForDummyData(layer, op), "Caffe layer name = %s, layer type= %s, parse params failed", | ||||
| layer->name().c_str(), layer->type().c_str()); | layer->name().c_str(), layer->type().c_str()); | ||||
| } else { | } else { | ||||
| @@ -77,12 +75,14 @@ Status CaffeDataParser::ParseParamsForInput(const domi::caffe::LayerParameter *l | |||||
| } | } | ||||
| for (int i = 0; i < input_param.shape_size(); i++) { | for (int i = 0; i < input_param.shape_size(); i++) { | ||||
| const domi::caffe::BlobShape &blob_shape = input_param.shape(i); | const domi::caffe::BlobShape &blob_shape = input_param.shape(i); | ||||
| vector<int64_t> shape; | vector<int64_t> shape; | ||||
| unordered_map<string, vector<int64_t>> &shape_map = GetParserContext().input_dims; | |||||
| unordered_map<string, vector<int64_t>> &shape_map = domi::GetContext().input_dims; | |||||
| std::vector<int64_t> model_dims; | std::vector<int64_t> model_dims; | ||||
| for (auto &blob_shape_dim_temp : blob_shape.dim()) { | for (auto &blob_shape_dim_temp : blob_shape.dim()) { | ||||
| model_dims.push_back(blob_shape_dim_temp); | model_dims.push_back(blob_shape_dim_temp); | ||||
| } | } | ||||
| string name = layer->name(); | string name = layer->name(); | ||||
| GE_IF_BOOL_EXEC(shape_map.count(name) != 0, model_dims = shape_map.at(name)); | GE_IF_BOOL_EXEC(shape_map.count(name) != 0, model_dims = shape_map.at(name)); | ||||
| GE_CHK_STATUS_RET(GetOutputDesc(name, model_dims.size(), model_dims, op), "Get output desc failed in layer %s", | GE_CHK_STATUS_RET(GetOutputDesc(name, model_dims.size(), model_dims, op), "Get output desc failed in layer %s", | ||||
| @@ -90,7 +90,7 @@ Status CaffeDataParser::ParseParamsForInput(const domi::caffe::LayerParameter *l | |||||
| } | } | ||||
| } else { | } else { | ||||
| // Get from external input | // Get from external input | ||||
| const ge::ParserContext &ctx = GetParserContext(); | |||||
| const ge::OmgContext &ctx = domi::GetContext(); | |||||
| std::unordered_map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | std::unordered_map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | ||||
| string name = layer->name(); | string name = layer->name(); | ||||
| auto search = input_dims.find(name); | auto search = input_dims.find(name); | ||||
| @@ -124,7 +124,7 @@ Status CaffeDataParser::ParseParamsForDummyData(const domi::caffe::LayerParamete | |||||
| const domi::caffe::BlobShape &blob_shape = dummy_data_param.shape(i); | const domi::caffe::BlobShape &blob_shape = dummy_data_param.shape(i); | ||||
| vector<int64_t> shape; | vector<int64_t> shape; | ||||
| unordered_map<string, vector<int64_t>> &shape_map = GetParserContext().input_dims; | |||||
| unordered_map<string, vector<int64_t>> &shape_map = domi::GetContext().input_dims; | |||||
| std::vector<int64_t> model_dims; | std::vector<int64_t> model_dims; | ||||
| for (auto &blob_shape_dim_temp : blob_shape.dim()) { | for (auto &blob_shape_dim_temp : blob_shape.dim()) { | ||||
| model_dims.push_back(blob_shape_dim_temp); | model_dims.push_back(blob_shape_dim_temp); | ||||
| @@ -137,7 +137,7 @@ Status CaffeDataParser::ParseParamsForDummyData(const domi::caffe::LayerParamete | |||||
| } | } | ||||
| } else { | } else { | ||||
| // Get from external input | // Get from external input | ||||
| const ge::ParserContext &ctx = GetParserContext(); | |||||
| const ge::OmgContext &ctx = domi::GetContext(); | |||||
| std::unordered_map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | std::unordered_map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | ||||
| string name = layer->name(); | string name = layer->name(); | ||||
| auto search = input_dims.find(name); | auto search = input_dims.find(name); | ||||
| @@ -18,9 +18,6 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "common/util/error_manager/error_manager.h" | #include "common/util/error_manager/error_manager.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| using namespace ge::parser; | |||||
| using domi::CAFFE; | using domi::CAFFE; | ||||
| @@ -20,17 +20,16 @@ | |||||
| #include <iostream> | #include <iostream> | ||||
| #include <sstream> | #include <sstream> | ||||
| #include <memory> | #include <memory> | ||||
| #include <algorithm> | |||||
| #include "parser/common/convert/pb2json.h" | |||||
| #include "common/convert/pb2json.h" | |||||
| #include "common/debug/log.h" | #include "common/debug/log.h" | ||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "common/model_saver.h" | |||||
| #include "common/op_map.h" | #include "common/op_map.h" | ||||
| #include "common/util.h" | |||||
| #include "common/util/error_manager/error_manager.h" | #include "common/util/error_manager/error_manager.h" | ||||
| #include "common/ge_types.h" | |||||
| #include "common/string_util.h" | #include "common/string_util.h" | ||||
| #include "external/graph/operator_factory.h" | #include "external/graph/operator_factory.h" | ||||
| #include "external/parser/caffe_parser.h" | #include "external/parser/caffe_parser.h" | ||||
| #include "external/ge/ge_api_types.h" | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "graph/optimize/common/params.h" | #include "graph/optimize/common/params.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| @@ -47,8 +46,6 @@ | |||||
| #include "parser/caffe/caffe_op_parser.h" | #include "parser/caffe/caffe_op_parser.h" | ||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "parser/common/pre_checker.h" | #include "parser/common/pre_checker.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "parser/common/model_saver.h" | |||||
| #include "parser/common/acl_graph_parser_util.h" | #include "parser/common/acl_graph_parser_util.h" | ||||
| #include "parser/common/proto_file_parser.h" | #include "parser/common/proto_file_parser.h" | ||||
| #include "register/op_registry.h" | #include "register/op_registry.h" | ||||
| @@ -58,7 +55,7 @@ using domi::caffe::NetParameter; | |||||
| using domi::ParseParamByOpFunc; | using domi::ParseParamByOpFunc; | ||||
| using ge::caffe_op_map; | using ge::caffe_op_map; | ||||
| using ge::CaffeOpParser; | using ge::CaffeOpParser; | ||||
| using ge::parser::ModelSaver; | |||||
| using ge::ModelSaver; | |||||
| using ge::OpParser; | using ge::OpParser; | ||||
| using ge::OpParserFactory; | using ge::OpParserFactory; | ||||
| using ge::Pb2Json; | using ge::Pb2Json; | ||||
| @@ -77,7 +74,7 @@ using std::ifstream; | |||||
| namespace ge { | namespace ge { | ||||
| graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, ge::Graph &graph) { | graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, ge::Graph &graph) { | ||||
| GE_CHECK_NOTNULL(model_file); | GE_CHECK_NOTNULL(model_file); | ||||
| GetParserContext().type = domi::CAFFE; | |||||
| domi::GetContext().type = domi::CAFFE; | |||||
| std::map<string, string> options; | std::map<string, string> options; | ||||
| options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::CAFFE))); | options.insert(std::pair<string, string>(string(ge::FRAMEWORK_TYPE), to_string(ge::CAFFE))); | ||||
| @@ -86,7 +83,7 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, | |||||
| (void)acl_graph_parse_util.AclParserInitialize(options); | (void)acl_graph_parse_util.AclParserInitialize(options); | ||||
| // Create an empty computegraph | // Create an empty computegraph | ||||
| ge::ComputeGraphPtr compute_graph = ge::parser::MakeShared<ge::ComputeGraph>("tmpGraph"); | |||||
| ge::ComputeGraphPtr compute_graph = ge::MakeShared<ge::ComputeGraph>("tmpGraph"); | |||||
| GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
| graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); | ||||
| @@ -108,10 +105,6 @@ graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| GELOGI("Weights parse success. graph: %s", graph.GetName().c_str()); | GELOGI("Weights parse success. graph: %s", graph.GetName().c_str()); | ||||
| if (acl_graph_parse_util.SetDefaultOutputNode(graph) != ge::SUCCESS) { | |||||
| GELOGE(ret, "Set graph %s default output node failed.", graph.GetName().c_str()); | |||||
| return ge::FAILED; | |||||
| } | |||||
| return ge::SUCCESS; | return ge::SUCCESS; | ||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -155,15 +148,14 @@ const std::string kRepeated = "repeated"; | |||||
| const std::string kRequired = "required"; | const std::string kRequired = "required"; | ||||
| const std::string kCustom = "custom"; | const std::string kCustom = "custom"; | ||||
| const std::string kBuiltin = "built-in"; | const std::string kBuiltin = "built-in"; | ||||
| std::vector<std::string> kAddTensorIrSkipNodes = {ge::parser::DATA, ge::parser::YOLODETECTIONOUTPUT, | |||||
| ge::parser::NETOUTPUT}; | |||||
| std::vector<std::string> kAddTensorIrSkipNodes = {ge::DATA, ge::YOLODETECTIONOUTPUT, ge::NETOUTPUT}; | |||||
| const std::set<std::string> kCustomProtoLayerCommonField = {"name", "type"}; | const std::set<std::string> kCustomProtoLayerCommonField = {"name", "type"}; | ||||
| const std::set<std::string> kCaffeProtoLayerCommonField = {"name", "type", "bottom", "top", "phase", | const std::set<std::string> kCaffeProtoLayerCommonField = {"name", "type", "bottom", "top", "phase", | ||||
| "loss_weight", "param", "blobs", "propagate_down", | "loss_weight", "param", "blobs", "propagate_down", | ||||
| "include", "exclude"}; | "include", "exclude"}; | ||||
| Status CheckPathValid(const char *model_path, const string &custom_proto, string &custom_proto_path, | Status CheckPathValid(const char *model_path, const string &custom_proto, string &custom_proto_path, | ||||
| string &custom_proto_name) { | string &custom_proto_name) { | ||||
| string path_model = ge::parser::RealPath(model_path); | |||||
| string path_model = ge::RealPath(model_path); | |||||
| if (path_model.empty()) { | if (path_model.empty()) { | ||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"}, {model_path, strerror(errno)}); | ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"}, {model_path, strerror(errno)}); | ||||
| GELOGE(FAILED, "Invalid path of model: %s", model_path); | GELOGE(FAILED, "Invalid path of model: %s", model_path); | ||||
| @@ -219,7 +211,7 @@ Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bo | |||||
| domi::caffe::LayerParameter *layer = proto_message.add_layer(); | domi::caffe::LayerParameter *layer = proto_message.add_layer(); | ||||
| GE_CHECK_NOTNULL(layer); | GE_CHECK_NOTNULL(layer); | ||||
| layer->set_name(proto_message.input(i)); | layer->set_name(proto_message.input(i)); | ||||
| layer->set_type(ge::parser::INPUT_TYPE); | |||||
| layer->set_type(ge::INPUT_TYPE); | |||||
| layer->add_top(proto_message.input(i)); | layer->add_top(proto_message.input(i)); | ||||
| domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | ||||
| @@ -248,7 +240,7 @@ Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bo | |||||
| domi::caffe::LayerParameter *layer = proto_message.add_layer(); | domi::caffe::LayerParameter *layer = proto_message.add_layer(); | ||||
| GE_CHECK_NOTNULL(layer); | GE_CHECK_NOTNULL(layer); | ||||
| layer->set_name(proto_message.input(i)); | layer->set_name(proto_message.input(i)); | ||||
| layer->set_type(ge::parser::INPUT_TYPE); | |||||
| layer->set_type(ge::INPUT_TYPE); | |||||
| layer->add_top(proto_message.input(i)); | layer->add_top(proto_message.input(i)); | ||||
| domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | ||||
| @@ -263,7 +255,7 @@ Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bo | |||||
| input_data_flag = true; | input_data_flag = true; | ||||
| } | } | ||||
| } else { | } else { | ||||
| const ge::ParserContext &ctx = ge::GetParserContext(); | |||||
| const ge::OmgContext &ctx = domi::GetContext(); | |||||
| std::unordered_map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | std::unordered_map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | ||||
| for (int i = 0; i < proto_message.input_size(); i++) { | for (int i = 0; i < proto_message.input_size(); i++) { | ||||
| string name = proto_message.input(i); | string name = proto_message.input(i); | ||||
| @@ -278,7 +270,7 @@ Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bo | |||||
| domi::caffe::LayerParameter *layer = proto_message.add_layer(); | domi::caffe::LayerParameter *layer = proto_message.add_layer(); | ||||
| GE_CHECK_NOTNULL(layer); | GE_CHECK_NOTNULL(layer); | ||||
| layer->set_name(name); | layer->set_name(name); | ||||
| layer->set_type(ge::parser::INPUT_TYPE); | |||||
| layer->set_type(ge::INPUT_TYPE); | |||||
| layer->add_top(proto_message.input(i)); | layer->add_top(proto_message.input(i)); | ||||
| domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | domi::caffe::InputParameter *input_param = layer->mutable_input_param(); | ||||
| @@ -343,7 +335,7 @@ Status CaffeModelParser::ParseNetModelByCustomProto(const char *model_path, cons | |||||
| Status CaffeModelParser::CustomProtoParse(const char *model_path, const string &custom_proto, | Status CaffeModelParser::CustomProtoParse(const char *model_path, const string &custom_proto, | ||||
| const string &caffe_proto, vector<ge::Operator> &operators) { | const string &caffe_proto, vector<ge::Operator> &operators) { | ||||
| string custom_proto_path = ge::parser::RealPath(custom_proto.c_str()); | |||||
| string custom_proto_path = ge::RealPath(custom_proto.c_str()); | |||||
| if (custom_proto_path.empty()) { | if (custom_proto_path.empty()) { | ||||
| GELOGW("Valid custom proto: %s does not exist, skip parsing custom proto", custom_proto.c_str()); | GELOGW("Valid custom proto: %s does not exist, skip parsing custom proto", custom_proto.c_str()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -749,27 +741,27 @@ Status CaffeModelParser::ParseRepeatedField(const google::protobuf::Reflection * | |||||
| } | } | ||||
| void CaffeModelParser::AddOutputInfoToContext(string layer_name, int32_t top_index) { | void CaffeModelParser::AddOutputInfoToContext(string layer_name, int32_t top_index) { | ||||
| auto iter_node_name = ge::GetParserContext().out_nodes_map.find(layer_name); | |||||
| if (iter_node_name != ge::GetParserContext().out_nodes_map.end()) { | |||||
| auto iter_node_name = domi::GetContext().out_nodes_map.find(layer_name); | |||||
| if (iter_node_name != domi::GetContext().out_nodes_map.end()) { | |||||
| iter_node_name->second.emplace_back(top_index); | iter_node_name->second.emplace_back(top_index); | ||||
| } else { | } else { | ||||
| std::vector<int32_t> index_v; | std::vector<int32_t> index_v; | ||||
| index_v.emplace_back(top_index); | index_v.emplace_back(top_index); | ||||
| ge::GetParserContext().out_nodes_map.emplace(layer_name, index_v); | |||||
| domi::GetContext().out_nodes_map.emplace(layer_name, index_v); | |||||
| } | } | ||||
| ge::GetParserContext().user_out_nodes.push_back(std::make_pair(layer_name, top_index)); | |||||
| domi::GetContext().user_out_nodes.push_back(std::make_pair(layer_name, top_index)); | |||||
| } | } | ||||
| Status CaffeModelParser::ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message) { | Status CaffeModelParser::ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message) { | ||||
| if (ge::GetParserContext().user_out_nodes_top_vec.empty()) { | |||||
| if (domi::GetContext().user_out_nodes_top_vec.empty()) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| ge::GetParserContext().out_nodes_map.clear(); | |||||
| ge::GetParserContext().user_out_nodes.clear(); | |||||
| domi::GetContext().out_nodes_map.clear(); | |||||
| domi::GetContext().user_out_nodes.clear(); | |||||
| int32_t layer_count = proto_message.layer_size(); | int32_t layer_count = proto_message.layer_size(); | ||||
| const std::vector<string> &user_out_nodes_top_vec = | const std::vector<string> &user_out_nodes_top_vec = | ||||
| ge::GetParserContext().user_out_nodes_top_vec; | |||||
| domi::GetContext().user_out_nodes_top_vec; | |||||
| for (const auto &top_name : user_out_nodes_top_vec) { | for (const auto &top_name : user_out_nodes_top_vec) { | ||||
| bool find_node_falg = false; | bool find_node_falg = false; | ||||
| @@ -808,6 +800,10 @@ Status CaffeModelParser::ParseOutputNodeTopInfo(const domi::caffe::NetParameter | |||||
| Status CaffeModelParser::AddBlobsToMap(const domi::caffe::LayerParameter &layer, | Status CaffeModelParser::AddBlobsToMap(const domi::caffe::LayerParameter &layer, | ||||
| std::map<std::string, std::string> &inplace_blob_name_remapping) { | std::map<std::string, std::string> &inplace_blob_name_remapping) { | ||||
| if (layer.type() == ge::NETOUTPUT) { | |||||
| return SUCCESS; | |||||
| } | |||||
| if (layer.top_size() <= 0) { | if (layer.top_size() <= 0) { | ||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19011", {"opname"}, {layer.name()}); | ErrorManager::GetInstance().ATCReportErrMessage("E19011", {"opname"}, {layer.name()}); | ||||
| GELOGE(FAILED, "The output size of layer %s needs to be greater than zero.", layer.name().c_str()); | GELOGE(FAILED, "The output size of layer %s needs to be greater than zero.", layer.name().c_str()); | ||||
| @@ -966,9 +962,9 @@ Status CaffeModelParser::AddNode(const domi::caffe::LayerParameter &layer, ge::C | |||||
| } else { | } else { | ||||
| op_type = layer.type(); | op_type = layer.type(); | ||||
| // User defined duplicate name operator processing | // User defined duplicate name operator processing | ||||
| auto m_iter = ge::GetParserContext().op_conf_map.find(op_type); | |||||
| auto m_iter = domi::GetContext().op_conf_map.find(op_type); | |||||
| // User specified configuration item found | // User specified configuration item found | ||||
| if (m_iter != ge::GetParserContext().op_conf_map.end()) { | |||||
| if (m_iter != domi::GetContext().op_conf_map.end()) { | |||||
| op_type = m_iter->second; | op_type = m_iter->second; | ||||
| } | } | ||||
| // General layer layer, search optype | // General layer layer, search optype | ||||
| @@ -1057,7 +1053,7 @@ Status CaffeModelParser::AddNode(const domi::caffe::LayerParameter &layer, ge::C | |||||
| Status CaffeModelParser::AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer) { | Status CaffeModelParser::AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer) { | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| // Data node input and output tensordesc added in parserparam | // Data node input and output tensordesc added in parserparam | ||||
| if (op_desc->GetType() == ge::parser::DATA) { | |||||
| if (op_desc->GetType() == ge::DATA) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -1077,7 +1073,7 @@ Status CaffeModelParser::AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const dom | |||||
| } | } | ||||
| // yolo v2 YoloDetectionOutput | // yolo v2 YoloDetectionOutput | ||||
| if (op_desc->GetType() == ge::parser::YOLODETECTIONOUTPUT) { | |||||
| if (op_desc->GetType() == ge::YOLODETECTIONOUTPUT) { | |||||
| ge::GeTensorDesc input_tensor; | ge::GeTensorDesc input_tensor; | ||||
| GE_RETURN_IF_ERROR(op_desc->AddInputDesc(input_tensor)); | GE_RETURN_IF_ERROR(op_desc->AddInputDesc(input_tensor)); | ||||
| GE_RETURN_IF_ERROR(op_desc->AddInputDesc(input_tensor)); | GE_RETURN_IF_ERROR(op_desc->AddInputDesc(input_tensor)); | ||||
| @@ -1086,13 +1082,41 @@ Status CaffeModelParser::AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const dom | |||||
| "while it's original input num is: %d", | "while it's original input num is: %d", | ||||
| layer.bottom_size()); | layer.bottom_size()); | ||||
| } | } | ||||
| // Netoutput node processing | |||||
| if (op_desc->GetType() == ge::NETOUTPUT) { | |||||
| size_t input_output_tensor_num = 0; | |||||
| if (!domi::GetContext().user_out_nodes.empty()) { | |||||
| // User specified output | |||||
| input_output_tensor_num = domi::GetContext().user_out_nodes.size(); | |||||
| } else { | |||||
| for (auto t_iter = top_blobs_map_.begin(); t_iter != top_blobs_map_.end(); t_iter++) { | |||||
| auto b_iter = bottom_blobs_map_.find(t_iter->first); | |||||
| // Find the output node of the network | |||||
| if (b_iter == bottom_blobs_map_.end()) { | |||||
| input_output_tensor_num += top_blobs_map_[t_iter->first].size(); | |||||
| } | |||||
| } | |||||
| } | |||||
| // add tensordesc | |||||
| GELOGD( | |||||
| "Current op type is NETOUTPUT, add additional input&output num: %zu." | |||||
| "while it's original input num is: %d, output num is: %d", | |||||
| input_output_tensor_num, layer.bottom_size(), output_tensor_num); | |||||
| for (size_t i = 0; i < input_output_tensor_num; i++) { | |||||
| ge::GeTensorDesc input_tensor; | |||||
| GE_RETURN_IF_ERROR(op_desc->AddInputDesc(input_tensor)); | |||||
| ge::GeTensorDesc output_tensor; | |||||
| GE_RETURN_IF_ERROR(op_desc->AddOutputDesc(output_tensor)); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status CaffeModelParser::AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer, | Status CaffeModelParser::AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer, | ||||
| const string &op_type) { | const string &op_type) { | ||||
| if (std::find(kAddTensorIrSkipNodes.begin(), kAddTensorIrSkipNodes.end(), op_type) != kAddTensorIrSkipNodes.end()) { | if (std::find(kAddTensorIrSkipNodes.begin(), kAddTensorIrSkipNodes.end(), op_type) != kAddTensorIrSkipNodes.end()) { | ||||
| op_desc = ge::parser::MakeShared<ge::OpDesc>(layer.name(), op_type); | |||||
| op_desc = ge::MakeShared<ge::OpDesc>(layer.name(), op_type); | |||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| Status ret = AddTensorDescToOpDesc(op_desc, layer); | Status ret = AddTensorDescToOpDesc(op_desc, layer); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| @@ -1224,8 +1248,8 @@ Status CaffeModelParser::AddEdges(ge::ComputeGraphPtr &graph) { | |||||
| bool CaffeModelParser::IsOutputTop(const string &op_name, const int32_t index) { | bool CaffeModelParser::IsOutputTop(const string &op_name, const int32_t index) { | ||||
| bool ret = false; | bool ret = false; | ||||
| auto iter = ge::GetParserContext().out_nodes_map.find(op_name); | |||||
| if (iter != ge::GetParserContext().out_nodes_map.end()) { | |||||
| auto iter = domi::GetContext().out_nodes_map.find(op_name); | |||||
| if (iter != domi::GetContext().out_nodes_map.end()) { | |||||
| std::vector<int32_t> tmp_index_v; | std::vector<int32_t> tmp_index_v; | ||||
| for (int32_t id : iter->second) { | for (int32_t id : iter->second) { | ||||
| if (index == id) { | if (index == id) { | ||||
| @@ -1236,40 +1260,53 @@ bool CaffeModelParser::IsOutputTop(const string &op_name, const int32_t index) { | |||||
| } | } | ||||
| // To prevent specifying network output again in the build phase, need to delete the output node in the map list. | // To prevent specifying network output again in the build phase, need to delete the output node in the map list. | ||||
| if (ret) { | if (ret) { | ||||
| ge::GetParserContext().out_nodes_map.erase(op_name); | |||||
| ge::GetParserContext().out_nodes_map.emplace(op_name, tmp_index_v); | |||||
| domi::GetContext().out_nodes_map.erase(op_name); | |||||
| domi::GetContext().out_nodes_map.emplace(op_name, tmp_index_v); | |||||
| } | } | ||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| Status CaffeModelParser::AddUserOutNodesTop() { | |||||
| Status CaffeModelParser::AddEdgeForUserOutNodes(ge::ComputeGraphPtr &graph) { | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| ge::NodePtr net_output_node = graph->FindFirstNodeMatchType(ge::NETOUTPUT); | |||||
| if (net_output_node == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "Can not find netoutput node."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| uint32_t net_output_num = net_output_node->GetAllInDataAnchorsSize(); | |||||
| int32_t index = 0; | int32_t index = 0; | ||||
| const std::vector<std::pair<std::string, int32_t>> &user_out_nodes = ge::GetParserContext().user_out_nodes; | |||||
| int net_output_num = user_out_nodes.size(); | |||||
| for (const auto &out_pair : user_out_nodes) { | |||||
| auto layer_iter = layer_tops_map_.find(out_pair.first); | |||||
| std::vector<std::pair<std::string, int32_t>> &user_out_nodes = domi::GetContext().user_out_nodes; | |||||
| for (auto &out_pair : user_out_nodes) { | |||||
| auto node_iter = node_map.find(out_pair.first); | |||||
| GELOGI("Add to output, node name: %s", out_pair.first.c_str()); | GELOGI("Add to output, node name: %s", out_pair.first.c_str()); | ||||
| if (layer_iter != layer_tops_map_.end()) { | |||||
| if (static_cast<uint32_t>(out_pair.second) >= (layer_iter->second).size()) { | |||||
| if (node_iter != node_map.end()) { | |||||
| if ((static_cast<uint32_t>(out_pair.second) >= node_iter->second->GetAllOutDataAnchorsSize()) || | |||||
| (static_cast<uint32_t>(index) >= net_output_num)) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | ErrorManager::GetInstance().ATCReportErrMessage( | ||||
| "E11016", {"opname", "outputindex", "totlaloutputindex", "inputindex", "totlalinputindex"}, | "E11016", {"opname", "outputindex", "totlaloutputindex", "inputindex", "totlalinputindex"}, | ||||
| {out_pair.first.c_str(), std::to_string(out_pair.second), | {out_pair.first.c_str(), std::to_string(out_pair.second), | ||||
| std::to_string((layer_iter->second).size()), std::to_string(index), | |||||
| std::to_string(node_iter->second->GetAllOutDataAnchorsSize()), std::to_string(index), | |||||
| std::to_string(net_output_num)}); | std::to_string(net_output_num)}); | ||||
| GELOGE(INTERNAL_ERROR, | GELOGE(INTERNAL_ERROR, | ||||
| "Add op %s to NetOutput faild, current node output index:%d should < %u. NetOutput" | "Add op %s to NetOutput faild, current node output index:%d should < %u. NetOutput" | ||||
| "input_index:%d should < %u.", | "input_index:%d should < %u.", | ||||
| out_pair.first.c_str(), out_pair.second, (layer_iter->second).size(), index, | |||||
| out_pair.first.c_str(), out_pair.second, node_iter->second->GetAllOutDataAnchorsSize(), index, | |||||
| net_output_num); | net_output_num); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| string top_name = layer_iter->second[out_pair.second]; | |||||
| auto top_node_iter = node_map.find(out_pair.first); | |||||
| if (top_node_iter != node_map.end()) { | |||||
| ge::GetParserContext().out_top_names.push_back(top_name); | |||||
| GELOGI("The top of out node [%s] is [%s]", out_pair.first.c_str(), top_name.c_str()); | |||||
| GELOGD("Start add edge for user out node: From %s:%d To %s:%d.", node_iter->second->GetName().c_str(), | |||||
| out_pair.second, net_output_node->GetName().c_str(), index); | |||||
| ge::OutDataAnchorPtr out_archor_ptr = node_iter->second->GetOutDataAnchor(out_pair.second); | |||||
| GE_CHECK_NOTNULL(out_archor_ptr); | |||||
| ge::InDataAnchorPtr in_archor_ptr = net_output_node->GetInDataAnchor(index); | |||||
| GE_CHECK_NOTNULL(in_archor_ptr); | |||||
| if (ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E11013", {"opname1", "opname2"}, | |||||
| {node_iter->second->GetName(), net_output_node->GetName()}); | |||||
| GELOGE(INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", node_iter->second->GetName().c_str(), | |||||
| net_output_node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | } | ||||
| ++index; | ++index; | ||||
| } else { | } else { | ||||
| @@ -1281,7 +1318,13 @@ Status CaffeModelParser::AddUserOutNodesTop() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status CaffeModelParser::AddOutputTop(const domi::caffe::NetParameter &proto_message) { | |||||
| Status CaffeModelParser::AddEdge4Output(const domi::caffe::NetParameter &proto_message, ge::ComputeGraphPtr &graph) { | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| ge::NodePtr node = graph->FindFirstNodeMatchType(ge::NETOUTPUT); | |||||
| GE_RETURN_WITH_LOG_IF_FALSE(node != nullptr, "Net without output, some phase failed in front."); | |||||
| int32_t index = 0; | |||||
| for (int32_t i = 0; i < proto_message.layer_size(); i++) { | for (int32_t i = 0; i < proto_message.layer_size(); i++) { | ||||
| const domi::caffe::LayerParameter &layer = proto_message.layer(i); | const domi::caffe::LayerParameter &layer = proto_message.layer(i); | ||||
| @@ -1291,7 +1334,6 @@ Status CaffeModelParser::AddOutputTop(const domi::caffe::NetParameter &proto_mes | |||||
| for (int i = 0; i < layer.top_size(); i++) { | for (int i = 0; i < layer.top_size(); i++) { | ||||
| string top = layer.top(i); | string top = layer.top(i); | ||||
| string top_origin = top; | |||||
| // Handling 'inplace' scenarios | // Handling 'inplace' scenarios | ||||
| if (IsInplaceTopBlob(layer, top)) { | if (IsInplaceTopBlob(layer, top)) { | ||||
| top = RemapTopNameByLayer(layer, top, i); | top = RemapTopNameByLayer(layer, top, i); | ||||
| @@ -1313,9 +1355,21 @@ Status CaffeModelParser::AddOutputTop(const domi::caffe::NetParameter &proto_mes | |||||
| auto top_node_iter = node_map.find(layer.name()); | auto top_node_iter = node_map.find(layer.name()); | ||||
| GELOGI("output in top_blob: %s", layer.name().c_str()); | GELOGI("output in top_blob: %s", layer.name().c_str()); | ||||
| if (top_node_iter != node_map.end()) { | if (top_node_iter != node_map.end()) { | ||||
| ge::GetParserContext().out_top_names.push_back(top_origin); | |||||
| ge::GetParserContext().default_out_nodes.push_back(std::make_pair(layer.name(), (int32_t)i)); | |||||
| GELOGI("The top of out node [%s] is [%s]", layer.name().c_str(), top_origin.c_str()); | |||||
| // add edge | |||||
| // Output node, output index, input node, input index | |||||
| GELOGD("Start add edge for out node: From %s:%d To %s:%d.", top_node_iter->second->GetName().c_str(), i, | |||||
| node->GetName().c_str(), index); | |||||
| ge::OutDataAnchorPtr out_archor_ptr = top_node_iter->second->GetOutDataAnchor(i); | |||||
| GE_CHECK_NOTNULL(out_archor_ptr); | |||||
| ge::InDataAnchorPtr in_archor_ptr = node->GetInDataAnchor(index); | |||||
| GE_CHECK_NOTNULL(in_archor_ptr); | |||||
| GE_IF_BOOL_EXEC(ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS, | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E11013", {"opname1", "opname2"}, {top_node_iter->second->GetName(), node->GetName()}); | |||||
| GELOGE(INTERNAL_ERROR, "Add link failed from op[%s] to to op[%s].", | |||||
| top_node_iter->second->GetName().c_str(), node->GetName().c_str()); | |||||
| return INTERNAL_ERROR;); | |||||
| index++; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1370,7 +1424,7 @@ Status CaffeModelParser::PreCheck(const domi::caffe::NetParameter &net) { | |||||
| // validate opname | // validate opname | ||||
| string mode = "^[A-Za-z0-9./_-]+$"; | string mode = "^[A-Za-z0-9./_-]+$"; | ||||
| if (!ge::parser::ValidateStr(layer.name(), mode)) { | |||||
| if (!ge::ValidateStr(layer.name(), mode)) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E11018", {"opname"}, {layer.name()}); | ErrorManager::GetInstance().ATCReportErrMessage("E11018", {"opname"}, {layer.name()}); | ||||
| GELOGE(ge::FAILED, | GELOGE(ge::FAILED, | ||||
| "Parse caffe pbtxt validate op[%s] failed, opname can only contain " | "Parse caffe pbtxt validate op[%s] failed, opname can only contain " | ||||
| @@ -1399,7 +1453,7 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co | |||||
| domi::caffe::NetParameter proto_message; | domi::caffe::NetParameter proto_message; | ||||
| // Get Caffe network model information | // Get Caffe network model information | ||||
| if (!ge::parser::ReadProtoFromMem(data, static_cast<int>(size), &proto_message)) { | |||||
| if (!ge::ReadProtoFromMem(data, static_cast<int>(size), &proto_message)) { | |||||
| GELOGE(FAILED, "read proto from text ret fail"); | GELOGE(FAILED, "read proto from text ret fail"); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -1429,6 +1483,12 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co | |||||
| CHECK_FALSE_EXEC(ParseInput(proto_message, input_data_flag) == SUCCESS, has_error = true; | CHECK_FALSE_EXEC(ParseInput(proto_message, input_data_flag) == SUCCESS, has_error = true; | ||||
| GELOGE(FAILED, "ParseInput ret fail.")); | GELOGE(FAILED, "ParseInput ret fail.")); | ||||
| // build output layer | |||||
| domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||||
| GE_CHECK_NOTNULL(layer); | |||||
| layer->set_name(graph->GetName() + "_" + ge::NODE_NAME_NET_OUTPUT); | |||||
| layer->set_type(ge::NETOUTPUT); | |||||
| int32_t layer_count = proto_message.layer_size(); | int32_t layer_count = proto_message.layer_size(); | ||||
| std::map<std::string, std::string> inplace_blob_name_remapping; | std::map<std::string, std::string> inplace_blob_name_remapping; | ||||
| // Map of operator name and occurrence times | // Map of operator name and occurrence times | ||||
| @@ -1444,7 +1504,7 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co | |||||
| GE_CHK_BOOL_EXEC_INFO(CheckValidLayer(layer), continue, "layer phase is train, skip this layer, name:%s, type:%s.", | GE_CHK_BOOL_EXEC_INFO(CheckValidLayer(layer), continue, "layer phase is train, skip this layer, name:%s, type:%s.", | ||||
| layer.name().c_str(), layer.type().c_str()); | layer.name().c_str(), layer.type().c_str()); | ||||
| CHECK_FALSE_EXEC(!((layer.type() == ge::parser::DATA_TYPE) && (input_data_flag == true)), has_error = true; | |||||
| CHECK_FALSE_EXEC(!((layer.type() == ge::DATA_TYPE) && (input_data_flag == true)), has_error = true; | |||||
| GELOGE(FAILED, "net %s has input and data layer simultaneously.", proto_message.name().c_str())); | GELOGE(FAILED, "net %s has input and data layer simultaneously.", proto_message.name().c_str())); | ||||
| // All layer names cannot be duplicate | // All layer names cannot be duplicate | ||||
| @@ -1493,10 +1553,10 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(AddEdges(graph), "Caffe parser add edges fail."); | GE_RETURN_WITH_LOG_IF_ERROR(AddEdges(graph), "Caffe parser add edges fail."); | ||||
| if (!(ge::GetParserContext().user_out_nodes.empty())) { | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(AddUserOutNodesTop(), "Caffe parser add top_name for user out nodes failed."); | |||||
| if (!(domi::GetContext().user_out_nodes.empty())) { | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(AddEdgeForUserOutNodes(graph), "Caffe parser add edges for user out nodes failed."); | |||||
| } else { | } else { | ||||
| GE_RETURN_WITH_LOG_IF_ERROR(AddOutputTop(proto_message), "Caffe parser add top_name for output fail."); | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(AddEdge4Output(proto_message, graph), "Caffe parser add edges for output fail."); | |||||
| } | } | ||||
| GE_RETURN_WITH_LOG_IF_ERROR(graph->TopologicalSorting(), "Caffe parser call graph topo sort fail."); | GE_RETURN_WITH_LOG_IF_ERROR(graph->TopologicalSorting(), "Caffe parser call graph topo sort fail."); | ||||
| @@ -1540,34 +1600,6 @@ void CaffeModelParser::SaveOrigionLayerTops(domi::caffe::LayerParameter &layer) | |||||
| return; | return; | ||||
| } | } | ||||
| Status CaffeModelParser::SaveDataLayerTops(const domi::caffe::LayerParameter &layer) { | |||||
| string name = layer.name(); | |||||
| if (node_map.find(name) == node_map.end()) { | |||||
| GELOGE(FAILED, "Node can not be found by layer name: %s", name.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| ge::NodePtr node = node_map[name]; | |||||
| GE_CHECK_NOTNULL(node); | |||||
| if (node->GetType() == ge::parser::DATA) { | |||||
| if (layer.top_size() != 1) { | |||||
| GELOGE(FAILED, "Data layer[%s] top size must be 1, real size: %d", name.c_str(), layer.top_size()); | |||||
| return FAILED; | |||||
| } | |||||
| string top_name = layer.top(0); | |||||
| auto data_top_names = ge::GetParserContext().data_top_names; | |||||
| if (find(data_top_names.begin(), data_top_names.end(), top_name) != data_top_names.end()) { | |||||
| GELOGE(FAILED, "Different data can not have same top name: %s.", top_name.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| ge::GetParserContext().data_top_names.push_back(top_name); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &graph) { | Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &graph) { | ||||
| bool has_error = false; | bool has_error = false; | ||||
| GE_CHECK_NOTNULL(model_path); | GE_CHECK_NOTNULL(model_path); | ||||
| @@ -1626,20 +1658,25 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap | |||||
| CHECK_FALSE_EXEC(ParseInput(proto_message, input_data_flag) == SUCCESS, has_error = true; | CHECK_FALSE_EXEC(ParseInput(proto_message, input_data_flag) == SUCCESS, has_error = true; | ||||
| GELOGE(FAILED, "ParseInput ret fail.")); | GELOGE(FAILED, "ParseInput ret fail.")); | ||||
| // build output layer | |||||
| domi::caffe::LayerParameter *layer = proto_message.add_layer(); | |||||
| GE_CHECK_NOTNULL(layer); | |||||
| layer->set_name(graph->GetName() + "_" + ge::NODE_NAME_NET_OUTPUT); | |||||
| layer->set_type(ge::NETOUTPUT); | |||||
| int32_t layer_count = proto_message.layer_size(); | int32_t layer_count = proto_message.layer_size(); | ||||
| if (!ge::GetParserContext().user_out_nodes_top_vec.empty()) { | |||||
| if (!domi::GetContext().user_out_nodes_top_vec.empty()) { | |||||
| GELOGW("The out_put info has top_name items."); | GELOGW("The out_put info has top_name items."); | ||||
| GE_RETURN_WITH_LOG_IF_ERROR(ParseOutputNodeTopInfo(proto_message), | GE_RETURN_WITH_LOG_IF_ERROR(ParseOutputNodeTopInfo(proto_message), | ||||
| "Caffe parser parse output node-top info failed."); | "Caffe parser parse output node-top info failed."); | ||||
| ge::GetParserContext().user_out_nodes_top_vec.clear(); | |||||
| domi::GetContext().user_out_nodes_top_vec.clear(); | |||||
| } | } | ||||
| std::map<std::string, std::string> inplace_blob_name_remapping; | std::map<std::string, std::string> inplace_blob_name_remapping; | ||||
| // Map of operator name and occurrence times | // Map of operator name and occurrence times | ||||
| std::map<std::string, int32_t> layer_name_map; | std::map<std::string, int32_t> layer_name_map; | ||||
| GetParserContext().data_top_names.clear(); | |||||
| // <layername,paramnames> | // <layername,paramnames> | ||||
| std::map<std::string, std::vector<std::string>> layer_params_map; | std::map<std::string, std::vector<std::string>> layer_params_map; | ||||
| // same param name set <paramnames,layernames> | // same param name set <paramnames,layernames> | ||||
| @@ -1649,7 +1686,7 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap | |||||
| GE_CHK_BOOL_EXEC_INFO(CheckValidLayer(layer), continue, "layer phase is train, skip this layer, name:%s, type:%s.", | GE_CHK_BOOL_EXEC_INFO(CheckValidLayer(layer), continue, "layer phase is train, skip this layer, name:%s, type:%s.", | ||||
| layer.name().c_str(), layer.type().c_str()); | layer.name().c_str(), layer.type().c_str()); | ||||
| CHECK_FALSE_EXEC(!((layer.type() == ge::parser::DATA_TYPE) && (input_data_flag == true)), has_error = true; | |||||
| CHECK_FALSE_EXEC(!((layer.type() == ge::DATA_TYPE) && (input_data_flag == true)), has_error = true; | |||||
| GELOGE(FAILED, "net %s has input and data layer simultaneously.", proto_message.name().c_str())); | GELOGE(FAILED, "net %s has input and data layer simultaneously.", proto_message.name().c_str())); | ||||
| // All layer names cannot be duplicate | // All layer names cannot be duplicate | ||||
| @@ -1686,11 +1723,8 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(AddBlobsToMap(layer, inplace_blob_name_remapping), | GE_RETURN_WITH_LOG_IF_ERROR(AddBlobsToMap(layer, inplace_blob_name_remapping), | ||||
| "Caffe parser add blobs to map ret fail."); | "Caffe parser add blobs to map ret fail."); | ||||
| if (SaveDataLayerTops(layer) != SUCCESS) { | |||||
| GELOGE(FAILED, "Caffe parse: save data layer tops failed."); | |||||
| return FAILED; | |||||
| } | |||||
| } | } | ||||
| // Find a layer with the same param name and save it to graph | // Find a layer with the same param name and save it to graph | ||||
| GE_RETURN_WITH_LOG_IF_ERROR(FindShareParamLayers(layer_params_map), | GE_RETURN_WITH_LOG_IF_ERROR(FindShareParamLayers(layer_params_map), | ||||
| "Caffe parser find share param layers map ret fail."); | "Caffe parser find share param layers map ret fail."); | ||||
| @@ -1702,12 +1736,13 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(AddEdges(graph), "Caffe parser add edges fail."); | GE_RETURN_WITH_LOG_IF_ERROR(AddEdges(graph), "Caffe parser add edges fail."); | ||||
| if (!(ge::GetParserContext().user_out_nodes.empty())) { | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(AddUserOutNodesTop(), "Caffe parser add top_name for user out nodes failed."); | |||||
| if (!(domi::GetContext().user_out_nodes.empty())) { | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(AddEdgeForUserOutNodes(graph), "Caffe parser add edges for user out nodes failed."); | |||||
| } else { | } else { | ||||
| GE_RETURN_WITH_LOG_IF_ERROR(AddOutputTop(proto_message), "Caffe parser add top_name for output fail."); | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(AddEdge4Output(proto_message, graph), "Caffe parser add edges for output fail."); | |||||
| } | } | ||||
| GE_RETURN_WITH_LOG_IF_ERROR(graph->TopologicalSorting(), "Caffe parser call graph topo sort fail."); | GE_RETURN_WITH_LOG_IF_ERROR(graph->TopologicalSorting(), "Caffe parser call graph topo sort fail."); | ||||
| GE_RETURN_WITH_LOG_IF_ERROR(GetLeafNodeTops(graph), "Caffe parser get out nodes top names failed."); | |||||
| auto nodes = graph->GetDirectNode(); | auto nodes = graph->GetDirectNode(); | ||||
| GELOGI("graph node size = %zu.", nodes.size()); | GELOGI("graph node size = %zu.", nodes.size()); | ||||
| @@ -1800,7 +1835,7 @@ Status CaffeWeightsParser::ParseFromMemory(const char *data, uint32_t size, ge:: | |||||
| // Resolve proto file to netparameter | // Resolve proto file to netparameter | ||||
| NetParameter proto; | NetParameter proto; | ||||
| bool success = ge::parser::ReadProtoFromArray(reinterpret_cast<const char *>(data), static_cast<int>(size), &proto); | |||||
| bool success = ge::ReadProtoFromArray(reinterpret_cast<const char *>(data), static_cast<int>(size), &proto); | |||||
| if (!success) { | if (!success) { | ||||
| GELOGE(domi::PARSE_WEIGHTS_FAILED, "ReadProto from Memory fail"); | GELOGE(domi::PARSE_WEIGHTS_FAILED, "ReadProto from Memory fail"); | ||||
| return domi::PARSE_WEIGHTS_FAILED; | return domi::PARSE_WEIGHTS_FAILED; | ||||
| @@ -1848,7 +1883,7 @@ Status CaffeWeightsParser::Parse(const char *file, ge::ComputeGraphPtr &graph) { | |||||
| GELOGD("caffe_proto_path:%s custom_proto_path:%s", caffe_proto_path.c_str(), custom_proto_path.c_str()); | GELOGD("caffe_proto_path:%s custom_proto_path:%s", caffe_proto_path.c_str(), custom_proto_path.c_str()); | ||||
| string fusion_proto_file; | string fusion_proto_file; | ||||
| string custom_proto_file = ge::parser::RealPath(custom_proto_path.c_str()); | |||||
| string custom_proto_file = ge::RealPath(custom_proto_path.c_str()); | |||||
| if (custom_proto_file.empty()) { | if (custom_proto_file.empty()) { | ||||
| GELOGW("custom_proto_path:%s is not existed", custom_proto_path.c_str()); | GELOGW("custom_proto_path:%s is not existed", custom_proto_path.c_str()); | ||||
| fusion_proto_file = caffe_proto_path; | fusion_proto_file = caffe_proto_path; | ||||
| @@ -1860,7 +1895,7 @@ Status CaffeWeightsParser::Parse(const char *file, ge::ComputeGraphPtr &graph) { | |||||
| } | } | ||||
| } | } | ||||
| string fusion_proto_path = ge::parser::RealPath(fusion_proto_file.c_str()); | |||||
| string fusion_proto_path = ge::RealPath(fusion_proto_file.c_str()); | |||||
| GELOGI("Get fusion proto file[%s]-[%s].", fusion_proto_file.c_str(), fusion_proto_path.c_str()); | GELOGI("Get fusion proto file[%s]-[%s].", fusion_proto_file.c_str(), fusion_proto_path.c_str()); | ||||
| if (fusion_proto_path.empty()) { | if (fusion_proto_path.empty()) { | ||||
| GELOGE(FAILED, "Fusion proto file path [%s]-[%s] is not real existed.", fusion_proto_file.c_str(), | GELOGE(FAILED, "Fusion proto file path [%s]-[%s] is not real existed.", fusion_proto_file.c_str(), | ||||
| @@ -1913,7 +1948,7 @@ Status CaffeWeightsParser::ParseWeightByFusionProto(const char *weight_path, con | |||||
| google::protobuf::Message *message = proto->New(); | google::protobuf::Message *message = proto->New(); | ||||
| GE_CHECK_NOTNULL(message); | GE_CHECK_NOTNULL(message); | ||||
| if (!ge::parser::ReadProtoFromBinaryFile(weight_path, message)) { | |||||
| if (!ge::ReadProtoFromBinaryFile(weight_path, message)) { | |||||
| delete message; | delete message; | ||||
| message = nullptr; | message = nullptr; | ||||
| ErrorManager::GetInstance().ATCReportErrMessage( | ErrorManager::GetInstance().ATCReportErrMessage( | ||||
| @@ -2303,7 +2338,7 @@ Status CaffeWeightsParser::CheckNodes(ge::ComputeGraphPtr &graph) { | |||||
| auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| for (const auto &in_anchor_ptr : node->GetAllInDataAnchors()) { | for (const auto &in_anchor_ptr : node->GetAllInDataAnchors()) { | ||||
| if (op_desc->GetType() == ge::parser::DATA || op_desc->GetType() == ge::parser::CONSTANT) { | |||||
| if (op_desc->GetType() == ge::DATA || op_desc->GetType() == ge::CONSTANT) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto index = in_anchor_ptr->GetIdx(); | auto index = in_anchor_ptr->GetIdx(); | ||||
| @@ -2418,6 +2453,27 @@ Status CaffeWeightsParser::ConvertNetParameter(const NetParameter ¶m, ge::Co | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status CaffeModelParser::GetLeafNodeTops(ge::ComputeGraphPtr &graph) { | |||||
| auto netout = graph->FindFirstNodeMatchType(ge::NETOUTPUT); | |||||
| GE_CHECK_NOTNULL(netout); | |||||
| for (const auto &in_anchor : netout->GetAllInDataAnchors()) { | |||||
| auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(peer_out_data_anchor); | |||||
| auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(peer_out_data_node); | |||||
| int idx = peer_out_data_anchor->GetIdx(); | |||||
| string node_name = peer_out_data_node->GetName(); | |||||
| auto layer_iter = layer_tops_map_.find(node_name); | |||||
| if (layer_iter != layer_tops_map_.end()) { | |||||
| domi::GetContext().out_top_names.push_back(layer_iter->second[idx]); | |||||
| GELOGI("The top of out node [%s] is [%s]", node_name.c_str(), layer_iter->second[idx].c_str()); | |||||
| } else { | |||||
| GELOGW("The out node [%s] can not find its top.", node_name.c_str()); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status CaffeModelParser::ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) { | Status CaffeModelParser::ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) { | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -279,12 +279,12 @@ class CaffeModelParser : public domi::ModelParser { | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| * @brief Add top name information to graph | |||||
| * @param [in|out] proto_message | |||||
| * @brief Add edge information to graph | |||||
| * @param [in|out] graph graph for saving model information | |||||
| * @return SUCCESS add successfully | * @return SUCCESS add successfully | ||||
| * @return FAILED add failed | * @return FAILED add failed | ||||
| */ | */ | ||||
| Status AddOutputTop(const domi::caffe::NetParameter &proto_message); | |||||
| Status AddEdge4Output(const domi::caffe::NetParameter &proto_message, ge::ComputeGraphPtr &graph); | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -324,7 +324,7 @@ class CaffeModelParser : public domi::ModelParser { | |||||
| Status AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer, | Status AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer, | ||||
| const string &op_type); | const string &op_type); | ||||
| Status AddUserOutNodesTop(); | |||||
| Status AddEdgeForUserOutNodes(ge::ComputeGraphPtr &graph); | |||||
| std::string RemapTopNameByLayer(const domi::caffe::LayerParameter &layer, const std::string &top_name, int index); | std::string RemapTopNameByLayer(const domi::caffe::LayerParameter &layer, const std::string &top_name, int index); | ||||
| @@ -335,6 +335,8 @@ class CaffeModelParser : public domi::ModelParser { | |||||
| Status ParseOpParam(const domi::caffe::LayerParameter &layer, ge::OpDescPtr &op, | Status ParseOpParam(const domi::caffe::LayerParameter &layer, ge::OpDescPtr &op, | ||||
| std::shared_ptr<ge::OpParser> &op_parser); | std::shared_ptr<ge::OpParser> &op_parser); | ||||
| Status GetLeafNodeTops(ge::ComputeGraphPtr &graph); | |||||
| void SaveOrigionLayerTops(domi::caffe::LayerParameter &layer); | void SaveOrigionLayerTops(domi::caffe::LayerParameter &layer); | ||||
| Status ReorderInput(domi::caffe::NetParameter &net); | Status ReorderInput(domi::caffe::NetParameter &net); | ||||
| @@ -343,8 +345,6 @@ class CaffeModelParser : public domi::ModelParser { | |||||
| Status ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message); | Status ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message); | ||||
| Status SaveDataLayerTops(const domi::caffe::LayerParameter &layer); | |||||
| std::map<std::string, ge::NodePtr> node_map; | std::map<std::string, ge::NodePtr> node_map; | ||||
| // key: blob name, value: layer name and index | // key: blob name, value: layer name and index | ||||
| @@ -17,16 +17,14 @@ | |||||
| #include "parser/caffe/caffe_reshape_parser.h" | #include "parser/caffe/caffe_reshape_parser.h" | ||||
| #include <vector> | #include <vector> | ||||
| #include "common/debug/log.h" | #include "common/debug/log.h" | ||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "common/op/op_parser_util.h" | #include "common/op/op_parser_util.h" | ||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "proto/om.pb.h" | #include "proto/om.pb.h" | ||||
| using namespace ge::parser; | |||||
| using domi::CAFFE; | using domi::CAFFE; | ||||
| namespace ge { | namespace ge { | ||||
| @@ -109,7 +107,7 @@ Status CaffeReshapeParser::AddConstInput(ge::NodePtr &node) { | |||||
| } | } | ||||
| // construct GeTensorPtr | // construct GeTensorPtr | ||||
| ge::GeTensorPtr constTensor = ge::parser::MakeShared<ge::GeTensor>(); | |||||
| ge::GeTensorPtr constTensor = ge::MakeShared<ge::GeTensor>(); | |||||
| GE_CHECK_NOTNULL(constTensor); | GE_CHECK_NOTNULL(constTensor); | ||||
| constTensor->SetTensorDesc(const_desc); | constTensor->SetTensorDesc(const_desc); | ||||
| @@ -8,8 +8,7 @@ set(SRC_LIST | |||||
| "parser_inner_ctx.cc" | "parser_inner_ctx.cc" | ||||
| "proto_file_parser.cc" | "proto_file_parser.cc" | ||||
| "acl_graph_parser_util.cc" | "acl_graph_parser_util.cc" | ||||
| "tbe_plugin_loader.cc" | |||||
| "model_saver.cc" | |||||
| "../../../ge/common/model_saver.cc" | |||||
| "../tensorflow/tensorflow_custom_parser_adapter.cc" | "../tensorflow/tensorflow_custom_parser_adapter.cc" | ||||
| "../tensorflow/tensorflow_fusion_custom_parser_adapter.cc" | "../tensorflow/tensorflow_fusion_custom_parser_adapter.cc" | ||||
| "../tensorflow/tensorflow_fusion_op_parser.cc" | "../tensorflow/tensorflow_fusion_op_parser.cc" | ||||
| @@ -20,10 +19,9 @@ set(SRC_LIST | |||||
| "op_def/op_schema.cc" | "op_def/op_schema.cc" | ||||
| "op_def/operator.cc" | "op_def/operator.cc" | ||||
| "op_map.cc" | "op_map.cc" | ||||
| "parser_types.cc" | |||||
| "pass_manager.cc" | |||||
| "parser_fp16_t.cc" | |||||
| "thread_pool.cc" | |||||
| "../../../ge/graph/passes/pass_manager.cc" | |||||
| "../../../ge/common/thread_pool.cc" | |||||
| "parser_utils.cc" | |||||
| ) | ) | ||||
| ############ libparser_common.so ############ | ############ libparser_common.so ############ | ||||
| @@ -18,37 +18,20 @@ | |||||
| #include <dlfcn.h> | #include <dlfcn.h> | ||||
| #include <cstdlib> | #include <cstdlib> | ||||
| #include <fstream> | |||||
| #include <regex.h> | |||||
| #include <ctime> | |||||
| #include "common/string_util.h" | #include "common/string_util.h" | ||||
| #include "common/types.h" | |||||
| #include "common/debug/log.h" | #include "common/debug/log.h" | ||||
| #include "common/ge/tbe_plugin_manager.h" | |||||
| #include "common/op/ge_op_utils.h" | #include "common/op/ge_op_utils.h" | ||||
| #include "common/util.h" | |||||
| #include "ge/ge_api_types.h" | #include "ge/ge_api_types.h" | ||||
| #include "graph/opsproto_manager.h" | #include "graph/opsproto_manager.h" | ||||
| #include "omg/parser/parser_inner_ctx.h" | #include "omg/parser/parser_inner_ctx.h" | ||||
| #include "tbe_plugin_loader.h" | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "parser/common/register_tbe.h" | #include "parser/common/register_tbe.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "common/util/error_manager/error_manager.h" | |||||
| #include "google/protobuf/io/coded_stream.h" | |||||
| #include "google/protobuf/io/zero_copy_stream_impl.h" | |||||
| using google::protobuf::io::CodedInputStream; | |||||
| using google::protobuf::io::FileInputStream; | |||||
| using google::protobuf::io::ZeroCopyInputStream; | |||||
| using namespace ge::parser; | |||||
| namespace { | namespace { | ||||
| /// The maximum length of the file. | |||||
| /// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1 | |||||
| const int kMaxFileSizeLimit = INT_MAX; | |||||
| const int kMaxBuffSize = 256; | |||||
| const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte. | |||||
| const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M | |||||
| static string GetSoPath() { | static string GetSoPath() { | ||||
| Dl_info dl_info; | Dl_info dl_info; | ||||
| if (dladdr(reinterpret_cast<void *>(&GetSoPath), &dl_info) == 0) { | if (dladdr(reinterpret_cast<void *>(&GetSoPath), &dl_info) == 0) { | ||||
| @@ -77,7 +60,7 @@ static void GetOpsProtoPath(string &opsproto_path) { | |||||
| const char *path_env = std::getenv("ASCEND_OPP_PATH"); | const char *path_env = std::getenv("ASCEND_OPP_PATH"); | ||||
| if (path_env != nullptr) { | if (path_env != nullptr) { | ||||
| string path = path_env; | string path = path_env; | ||||
| string file_path = ge::parser::RealPath(path.c_str()); | |||||
| string file_path = ge::RealPath(path.c_str()); | |||||
| if (file_path.empty()) { | if (file_path.empty()) { | ||||
| GELOGE(ge::FAILED, "File path %s is invalid.", path.c_str()); | GELOGE(ge::FAILED, "File path %s is invalid.", path.c_str()); | ||||
| return; | return; | ||||
| @@ -125,7 +108,7 @@ domi::Status AclGrphParseUtil::GetOutputLeaf(NodePtr node, | |||||
| void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | void AclGrphParseUtil::GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | ||||
| std::vector<std::string> &output_nodes_name) { | std::vector<std::string> &output_nodes_name) { | ||||
| output_nodes_name.clear(); | output_nodes_name.clear(); | ||||
| if (ge::GetParserContext().out_top_names.empty()) { | |||||
| if (domi::GetContext().out_top_names.empty()) { | |||||
| // tf process, no top name. | // tf process, no top name. | ||||
| for (const auto output_node_info : output_nodes_info) { | for (const auto output_node_info : output_nodes_info) { | ||||
| std::string node_name = output_node_info.first->GetName(); | std::string node_name = output_node_info.first->GetName(); | ||||
| @@ -159,7 +142,7 @@ domi::Status AclGrphParseUtil::SetDefaultOutputNode(ge::Graph &graph) { | |||||
| AclGrphParseUtil::GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name); | AclGrphParseUtil::GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name); | ||||
| compute_graph->SetGraphOutNodesInfo(output_nodes_info); | compute_graph->SetGraphOutNodesInfo(output_nodes_info); | ||||
| ge::GetParserContext().net_out_nodes = output_nodes_name; | |||||
| domi::GetContext().net_out_nodes = output_nodes_name; | |||||
| GELOGI("Set graph %s default output node success.", graph.GetName().c_str()); | GELOGI("Set graph %s default output node success.", graph.GetName().c_str()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -211,7 +194,7 @@ domi::Status AclGrphParseUtil::AclParserInitialize(const std::map<std::string, s | |||||
| } | } | ||||
| // load custom op plugin | // load custom op plugin | ||||
| TBEPluginLoader::Instance().LoadPluginSo(options); | |||||
| TBEPluginManager::Instance().LoadPluginSo(options); | |||||
| // load and save custom op proto for prediction | // load and save custom op proto for prediction | ||||
| (void)LoadOpsProtoLib(); | (void)LoadOpsProtoLib(); | ||||
| @@ -239,254 +222,4 @@ domi::Status AclGrphParseUtil::AclParserInitialize(const std::map<std::string, s | |||||
| GELOGT(TRACE_STOP, "AclParserInitialize finished"); | GELOGT(TRACE_STOP, "AclParserInitialize finished"); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| namespace parser { | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char *path) { | |||||
| if (path == nullptr) { | |||||
| GELOGE(ge::FAILED, "path pointer is NULL."); | |||||
| return ""; | |||||
| } | |||||
| if (strlen(path) >= PATH_MAX) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, {path, std::to_string(PATH_MAX)}); | |||||
| GELOGE(ge::FAILED, "Path[%s] len is too long, it must be less than %d", path, PATH_MAX); | |||||
| return ""; | |||||
| } | |||||
| // Nullptr is returned when the path does not exist or there is no permission | |||||
| // Return absolute path when path is accessible | |||||
| std::string res; | |||||
| char resolved_path[PATH_MAX] = {0}; | |||||
| if (realpath(path, resolved_path) != nullptr) { | |||||
| res = resolved_path; | |||||
| } | |||||
| return res; | |||||
| } | |||||
| // Get file length | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY long GetFileLength(const std::string &input_file) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(input_file.empty(), return -1, "input_file path is null."); | |||||
| std::string real_path = RealPath(input_file.c_str()); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str()); | |||||
| unsigned long long file_length = 0; | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, | |||||
| {input_file, strerror(errno)}); | |||||
| return -1, "Open file[%s] failed. %s", input_file.c_str(), strerror(errno)); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0), | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19015", {"filepath"}, {input_file}); | |||||
| return -1, "File[%s] size is 0, not valid.", input_file.c_str()); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(file_length > kMaxFileSizeLimit, | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E19016", {"filepath", "filesize", "maxlen"}, | |||||
| {input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)}); | |||||
| return -1, "File[%s] size %lld is out of limit: %d.", | |||||
| input_file.c_str(), file_length, kMaxFileSizeLimit); | |||||
| return static_cast<long>(file_length); | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() { | |||||
| struct timeval tv{}; | |||||
| int ret = gettimeofday(&tv, nullptr); | |||||
| GE_LOGE_IF(ret != 0, "Func gettimeofday may failed: ret=%d", ret); | |||||
| auto total_use_time = tv.tv_usec + tv.tv_sec * 1000000; // 1000000: seconds to microseconds | |||||
| return static_cast<uint64_t>(total_use_time); | |||||
| } | |||||
| static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Message *proto) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(proto == nullptr, | |||||
| return false, "incorrect parameter. nullptr == proto"); | |||||
| coded_stream.SetTotalBytesLimit(kProtoReadBytesLimit, kWarningThreshold); | |||||
| return proto->ParseFromCodedStream(&coded_stream); | |||||
| } | |||||
| /** @ingroup domi_common | |||||
| * @brief Read all data from binary file | |||||
| * @param [in] file_name File path | |||||
| * @param [out] buffer The address of the output memory, which needs to be released by the caller | |||||
| * @param [out] length Output memory size | |||||
| * @return false fail | |||||
| * @return true success | |||||
| */ | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, | |||||
| int &length) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_name == nullptr), return false, "incorrect parameter. file is nullptr"); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((buffer == nullptr), return false, "incorrect parameter. buffer is nullptr"); | |||||
| std::string real_path = RealPath(file_name); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return false, "file path '%s' not valid", file_name); | |||||
| std::ifstream file(real_path.c_str(), std::ios::binary | std::ios::ate); | |||||
| if (!file.is_open()) { | |||||
| GELOGE(ge::FAILED, "Read file %s failed.", file_name); | |||||
| return false; | |||||
| } | |||||
| length = static_cast<int>(file.tellg()); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((length <= 0), file.close(); return false, "file length <= 0"); | |||||
| file.seekg(0, std::ios::beg); | |||||
| *buffer = new(std::nothrow) char[length](); | |||||
| GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(*buffer == nullptr, false, file.close(), "new an object failed."); | |||||
| file.read(*buffer, length); | |||||
| file.close(); | |||||
| return true; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(const char *file, Message *proto) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || proto == nullptr), | |||||
| return false, | |||||
| "Input parameter file or proto is nullptr!"); | |||||
| std::string real_path = RealPath(file); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), | |||||
| return false, "pb file path '%s' not valid", file); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); | |||||
| std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary); | |||||
| if (!fs.is_open()) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file, "ifstream is_open failed"}); | |||||
| GELOGE(ge::FAILED, "Open real path[%s] failed.", file); | |||||
| return false; | |||||
| } | |||||
| google::protobuf::io::IstreamInputStream istream(&fs); | |||||
| google::protobuf::io::CodedInputStream coded_stream(&istream); | |||||
| bool ret = ReadProtoFromCodedInputStream(coded_stream, proto); | |||||
| fs.close(); | |||||
| if (!ret) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19005", {"file"}, {file}); | |||||
| GELOGE(ge::FAILED, "Parse file[%s] failed.", file); | |||||
| return ret; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromArray(const void *data, int size, Message *proto) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((proto == nullptr || data == nullptr || size == 0), return false, | |||||
| "incorrect parameter. proto is nullptr || data is nullptr || size is 0"); | |||||
| google::protobuf::io::CodedInputStream coded_stream(reinterpret_cast<uint8_t *>(const_cast<void *>(data)), size); | |||||
| return ReadProtoFromCodedInputStream(coded_stream, proto); | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const char *file, | |||||
| google::protobuf::Message *message) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || message == nullptr), return false, | |||||
| "incorrect parameter. nullptr == file || nullptr == message"); | |||||
| std::string real_path = RealPath(file); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"}, | |||||
| {file, strerror(errno)}); | |||||
| return false, "Path[%s]'s realpath is empty, errmsg[%s]", file, | |||||
| strerror(errno)); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); | |||||
| std::ifstream fs(real_path.c_str(), std::ifstream::in); | |||||
| if (!fs.is_open()) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19017", {"realpth", "protofile"}, {real_path, file}); | |||||
| GELOGE(ge::FAILED, | |||||
| "Fail to open proto file real path is '%s' when orginal file path is '%s'.", real_path.c_str(), file); | |||||
| return false; | |||||
| } | |||||
| google::protobuf::io::IstreamInputStream input(&fs); | |||||
| bool ret = google::protobuf::TextFormat::Parse(&input, message); | |||||
| GE_IF_BOOL_EXEC(!ret, | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19018", {"protofile"}, {file}); | |||||
| GELOGE(ret, "Parse file[%s] through [google::protobuf::TextFormat::Parse] failed, " | |||||
| "please check whether the file is a valid protobuf format file.", file)); | |||||
| fs.close(); | |||||
| return ret; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const char *data, int size, | |||||
| google::protobuf::Message *message) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((data == nullptr || message == nullptr), return false, | |||||
| "incorrect parameter. data is nullptr || message is nullptr"); | |||||
| std::string str(data, static_cast<size_t>(size)); | |||||
| std::istringstream fs(str); | |||||
| google::protobuf::io::IstreamInputStream input(&fs); | |||||
| bool ret = google::protobuf::TextFormat::Parse(&input, message); | |||||
| GE_IF_BOOL_EXEC( | |||||
| !ret, GELOGE(ret, "Call [google::protobuf::TextFormat::Parse] func ret fail, please check your text file.")); | |||||
| return ret; | |||||
| } | |||||
| /// | |||||
| /// @brief get the Original Type of FrameworkOp | |||||
| /// @param [in] node | |||||
| /// @param [out] type | |||||
| /// @return Status | |||||
| /// | |||||
| Status GetOriginalType(const ge::NodePtr &node, string &type) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| type = node->GetType(); | |||||
| GE_IF_BOOL_EXEC(type != FRAMEWORKOP, return SUCCESS); | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
| bool ret = ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); | |||||
| if (!ret) { | |||||
| GELOGE(INTERNAL_ERROR, "Get FrameWorkOp original type [%s]", type.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| GELOGD("Get FrameWorkOp original type [%s]", type.c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &str, const std::string &mode) { | |||||
| char ebuff[kMaxBuffSize]; | |||||
| regex_t reg; | |||||
| int cflags = REG_EXTENDED | REG_NOSUB; | |||||
| int ret = regcomp(®, mode.c_str(), cflags); | |||||
| if (ret) { | |||||
| regerror(ret, ®, ebuff, kMaxBuffSize); | |||||
| GELOGW("regcomp failed, reason: %s", ebuff); | |||||
| regfree(®); | |||||
| return true; | |||||
| } | |||||
| ret = regexec(®, str.c_str(), 0, nullptr, 0); | |||||
| if (ret) { | |||||
| regerror(ret, ®, ebuff, kMaxBuffSize); | |||||
| GELOGE(ge::PARAM_INVALID, "regexec failed, reason: %s", ebuff); | |||||
| regfree(®); | |||||
| return false; | |||||
| } | |||||
| regfree(®); | |||||
| return true; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string CurrentTimeInStr() { | |||||
| std::time_t now = std::time(nullptr); | |||||
| std::tm *ptm = std::localtime(&now); | |||||
| if (ptm == nullptr) { | |||||
| GELOGE(ge::FAILED, "Localtime failed."); | |||||
| return ""; | |||||
| } | |||||
| const int kTimeBufferLen = 32; | |||||
| char buffer[kTimeBufferLen + 1] = {0}; | |||||
| // format: 20171122042550 | |||||
| std::strftime(buffer, kTimeBufferLen, "%Y%m%d%H%M%S", ptm); | |||||
| return std::string(buffer); | |||||
| } | |||||
| } // namespace parser | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -19,17 +19,10 @@ | |||||
| #include <map> | #include <map> | ||||
| #include <string> | #include <string> | ||||
| #include <google/protobuf/text_format.h> | |||||
| #include <sstream> | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "register/register_error_codes.h" | |||||
| #include "common/types.h" | |||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| using google::protobuf::Message; | |||||
| class AclGrphParseUtil { | class AclGrphParseUtil { | ||||
| public: | public: | ||||
| AclGrphParseUtil() {} | AclGrphParseUtil() {} | ||||
| @@ -45,189 +38,6 @@ class AclGrphParseUtil { | |||||
| void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | ||||
| std::vector<std::string> &output_nodes_name); | std::vector<std::string> &output_nodes_name); | ||||
| }; | }; | ||||
| namespace parser { | |||||
| /// | |||||
| /// @ingroup: domi_common | |||||
| /// @brief: get length of file | |||||
| /// @param [in] input_file: path of file | |||||
| /// @return long: File length. If the file length fails to be obtained, the value -1 is returned. | |||||
| /// | |||||
| extern long GetFileLength(const std::string &input_file); | |||||
| /// | |||||
| /// @ingroup domi_common | |||||
| /// @brief Absolute path for obtaining files. | |||||
| /// @param [in] path of input file | |||||
| /// @param [out] Absolute path of a file. If the absolute path cannot be obtained, an empty string is returned | |||||
| /// | |||||
| std::string RealPath(const char *path); | |||||
| /// | |||||
| /// @ingroup domi_common | |||||
| /// @brief Obtains the absolute time (timestamp) of the current system. | |||||
| /// @return Timestamp, in microseconds (US) | |||||
| /// | |||||
| /// | |||||
| uint64_t GetCurrentTimestamp(); | |||||
| /// | |||||
| /// @ingroup domi_common | |||||
| /// @brief Reads all data from a binary file. | |||||
| /// @param [in] file_name path of file | |||||
| /// @param [out] buffer Output memory address, which needs to be released by the caller. | |||||
| /// @param [out] length Output memory size | |||||
| /// @return false fail | |||||
| /// @return true success | |||||
| /// | |||||
| bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, int &length); | |||||
| /// | |||||
| /// @ingroup domi_common | |||||
| /// @brief proto file in bianary format | |||||
| /// @param [in] file path of proto file | |||||
| /// @param [out] proto memory for storing the proto file | |||||
| /// @return true success | |||||
| /// @return false fail | |||||
| /// | |||||
| bool ReadProtoFromBinaryFile(const char *file, Message *proto); | |||||
| /// | |||||
| /// @ingroup domi_common | |||||
| /// @brief Reads the proto structure from an array. | |||||
| /// @param [in] data proto data to be read | |||||
| /// @param [in] size proto data size | |||||
| /// @param [out] proto Memory for storing the proto file | |||||
| /// @return true success | |||||
| /// @return false fail | |||||
| /// | |||||
| bool ReadProtoFromArray(const void *data, int size, Message *proto); | |||||
| /// | |||||
| /// @ingroup domi_proto | |||||
| /// @brief Reads the proto file in the text format. | |||||
| /// @param [in] file path of proto file | |||||
| /// @param [out] message Memory for storing the proto file | |||||
| /// @return true success | |||||
| /// @return false fail | |||||
| /// | |||||
| bool ReadProtoFromText(const char *file, google::protobuf::Message *message); | |||||
| bool ReadProtoFromMem(const char *data, int size, google::protobuf::Message *message); | |||||
| /// | |||||
| /// @brief get the Original Type of FrameworkOp | |||||
| /// @param [in] node | |||||
| /// @param [out] type | |||||
| /// @return Status | |||||
| /// | |||||
| domi::Status GetOriginalType(const ge::NodePtr &node, string &type); | |||||
| /// | |||||
| /// @ingroup domi_common | |||||
| /// @brief Check whether the file path meets the whitelist verification requirements. | |||||
| /// @param [in] filePath file path | |||||
| /// @param [out] result | |||||
| /// | |||||
| bool ValidateStr(const std::string &filePath, const std::string &mode); | |||||
| /// | |||||
| /// @ingroup domi_common | |||||
| /// @brief Obtains the current time string. | |||||
| /// @return Time character string in the format: %Y%m%d%H%M%S, eg: 20171011083555 | |||||
| /// | |||||
| std::string CurrentTimeInStr(); | |||||
| template <typename T, typename... Args> | |||||
| static inline std::shared_ptr<T> MakeShared(Args &&... args) { | |||||
| typedef typename std::remove_const<T>::type T_nc; | |||||
| std::shared_ptr<T> ret(new (std::nothrow) T_nc(std::forward<Args>(args)...)); | |||||
| return ret; | |||||
| } | |||||
| /// @ingroup math_util | |||||
| /// @brief check whether int64 multiplication can result in overflow | |||||
| /// @param [in] a multiplicator | |||||
| /// @param [in] b multiplicator | |||||
| /// @return Status | |||||
| inline domi::Status Int64MulCheckOverflow(int64_t a, int64_t b) { | |||||
| if (a > 0) { | |||||
| if (b > 0) { | |||||
| if (a > (INT64_MAX / b)) { | |||||
| return domi::FAILED; | |||||
| } | |||||
| } else { | |||||
| if (b < (INT64_MIN / a)) { | |||||
| return domi::FAILED; | |||||
| } | |||||
| } | |||||
| } else { | |||||
| if (b > 0) { | |||||
| if (a < (INT64_MIN / b)) { | |||||
| return domi::FAILED; | |||||
| } | |||||
| } else { | |||||
| if ((a != 0) && (b < (INT64_MAX / a))) { | |||||
| return domi::FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| return domi::SUCCESS; | |||||
| } | |||||
| /// @ingroup math_util | |||||
| /// @brief check whether int64 multiplication can result in overflow | |||||
| /// @param [in] a multiplicator | |||||
| /// @param [in] b multiplicator | |||||
| /// @return Status | |||||
| inline domi::Status CheckInt64Uint32MulOverflow(int64_t a, uint32_t b) { | |||||
| if (a == 0 || b == 0) { | |||||
| return domi::SUCCESS; | |||||
| } | |||||
| if (a > 0) { | |||||
| if (a > (INT64_MAX / b)) { | |||||
| return domi::FAILED; | |||||
| } | |||||
| } else { | |||||
| if (a < (INT64_MIN / b)) { | |||||
| return domi::FAILED; | |||||
| } | |||||
| } | |||||
| return domi::SUCCESS; | |||||
| } | |||||
| #define PARSER_INT64_MULCHECK(a, b) \ | |||||
| if (ge::parser::Int64MulCheckOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Int64 %ld and %ld multiplication can result in overflow!", static_cast<int64_t>(a), \ | |||||
| static_cast<int64_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | |||||
| #define PARSER_INT64_UINT32_MULCHECK(a, b) \ | |||||
| if (ge::parser::CheckInt64Uint32MulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Int64 %ld and UINT32 %u multiplication can result in overflow!", static_cast<uint32_t>(a), \ | |||||
| static_cast<uint32_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | |||||
| } // namespace parser | |||||
| } // namespace ge | } // namespace ge | ||||
| /*lint --emacro((773),GE_TIMESTAMP_START)*/ | |||||
| /*lint -esym(773,GE_TIMESTAMP_START)*/ | |||||
| #define PARSER_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::parser::GetCurrentTimestamp() | |||||
| #define PARSER_TIMESTAMP_END(stage, stage_name) \ | |||||
| do { \ | |||||
| uint64_t endUsec_##stage = ge::parser::GetCurrentTimestamp(); \ | |||||
| GELOGI("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ | |||||
| (endUsec_##stage - startUsec_##stage)); \ | |||||
| } while (0); | |||||
| #define PARSER_TIMESTAMP_EVENT_END(stage, stage_name) \ | |||||
| do { \ | |||||
| uint64_t endUsec_##stage = ge::parser::GetCurrentTimestamp(); \ | |||||
| GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ | |||||
| (endUsec_##stage - startUsec_##stage)); \ | |||||
| } while (0); | |||||
| #endif // ACL_GRAPH_PARSE_UTIL_ | |||||
| #endif // ACL_GRAPH_PARSE_UTIL_ | |||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -115,7 +115,7 @@ void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescr | |||||
| case ProtobufFieldDescriptor::TYPE_FLOAT: | case ProtobufFieldDescriptor::TYPE_FLOAT: | ||||
| char str[kSignificantDigits]; | char str[kSignificantDigits]; | ||||
| if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1){ | |||||
| if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1) { | |||||
| json[field->name()] = str; | json[field->name()] = str; | ||||
| } else { | } else { | ||||
| json[field->name()] = reflection->GetFloat(message, field); | json[field->name()] = reflection->GetFloat(message, field); | ||||
| @@ -148,7 +148,7 @@ string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) { | |||||
| uint8_t *value = 0; | uint8_t *value = 0; | ||||
| value = reinterpret_cast<uint8_t *>(&temp_value); | value = reinterpret_cast<uint8_t *>(&temp_value); | ||||
| char str[kSignificantDigits]; | char str[kSignificantDigits]; | ||||
| if (sprintf_s(str, kSignificantDigits, "%d", *value) == -1){ | |||||
| if (sprintf_s(str, kSignificantDigits, "%d", *value) == -1) { | |||||
| GELOGW("Convert bytes to string fail, filed name:%s", field_name.c_str()); | GELOGW("Convert bytes to string fail, filed name:%s", field_name.c_str()); | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -17,8 +17,8 @@ | |||||
| // File: pb2json.h | // File: pb2json.h | ||||
| // Description: This header file for protobuf message and json interconversion | // Description: This header file for protobuf message and json interconversion | ||||
| #ifndef PARSER_COMMON_CONVERT_PB2JSON_H_ | |||||
| #define PARSER_COMMON_CONVERT_PB2JSON_H_ | |||||
| #ifndef GE_COMMON_CONVERT_PB2JSON_H_ | |||||
| #define GE_COMMON_CONVERT_PB2JSON_H_ | |||||
| #include <functional> | #include <functional> | ||||
| #include <memory> | #include <memory> | ||||
| #include <set> | #include <set> | ||||
| @@ -38,12 +38,12 @@ using ProtobufEnumValueDescriptor = ::google::protobuf::EnumValueDescriptor; | |||||
| class Pb2Json { | class Pb2Json { | ||||
| public: | public: | ||||
| /** | /** | ||||
| * @ingroup domi_omg | |||||
| * @brief Transfer protobuf object to JSON object | |||||
| * @param [out] json Converted JSON object | |||||
| * @return void success | |||||
| * @author | |||||
| */ | |||||
| * @ingroup domi_omg | |||||
| * @brief Transfer protobuf object to JSON object | |||||
| * @param [out] json Converted JSON object | |||||
| * @return void success | |||||
| * @author | |||||
| */ | |||||
| static void Message2Json(const ProtobufMsg &message, const std::set<std::string> &black_fields, Json &json, | static void Message2Json(const ProtobufMsg &message, const std::set<std::string> &black_fields, Json &json, | ||||
| bool enum2str = false); | bool enum2str = false); | ||||
| @@ -65,4 +65,4 @@ class Pb2Json { | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // PARSER_COMMON_CONVERT_PB2JSON_H_ | |||||
| #endif // GE_COMMON_CONVERT_PB2JSON_H_ | |||||
| @@ -18,7 +18,7 @@ | |||||
| #include <cstdlib> | #include <cstdlib> | ||||
| #include "common/debug/log.h" | #include "common/debug/log.h" | ||||
| #include "common/op/ge_op_utils.h" | #include "common/op/ge_op_utils.h" | ||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include "common/math/math_util.h" | |||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "omg/omg.h" | #include "omg/omg.h" | ||||
| @@ -36,7 +36,7 @@ FMK_FUNC_HOST_VISIBILITY Status DataOpParser::ParseShape(const vector<int64_t> & | |||||
| GE_RETURN_WITH_LOG_IF_FALSE(op != nullptr, "ParseShape failed for data_op, op is null"); | GE_RETURN_WITH_LOG_IF_FALSE(op != nullptr, "ParseShape failed for data_op, op is null"); | ||||
| const string &data_op_name = op->GetName(); | const string &data_op_name = op->GetName(); | ||||
| GetParserContext().input_dims.emplace(data_op_name, shape); | |||||
| domi::GetContext().input_dims.emplace(data_op_name, shape); | |||||
| int64_t attr_type = 0; | int64_t attr_type = 0; | ||||
| ge::DataType data_type; | ge::DataType data_type; | ||||
| @@ -51,7 +51,7 @@ FMK_FUNC_HOST_VISIBILITY Status DataOpParser::ParseShape(const vector<int64_t> & | |||||
| ge::GeTensorDesc i_tensor_desc; | ge::GeTensorDesc i_tensor_desc; | ||||
| ge::GeTensorDesc o_tensor_desc; | ge::GeTensorDesc o_tensor_desc; | ||||
| const unordered_map<string, domiTensorFormat_t> &input_nodes_format_map = GetParserContext().input_nodes_format_map; | |||||
| const unordered_map<string, domiTensorFormat_t> &input_nodes_format_map = domi::GetContext().input_nodes_format_map; | |||||
| auto map_iter = input_nodes_format_map.find(data_op_name); | auto map_iter = input_nodes_format_map.find(data_op_name); | ||||
| if (map_iter != input_nodes_format_map.end() && map_iter->second == domi::DOMI_TENSOR_NC1HWC0) { | if (map_iter != input_nodes_format_map.end() && map_iter->second == domi::DOMI_TENSOR_NC1HWC0) { | ||||
| // Input 5D NC1HWC0 | // Input 5D NC1HWC0 | ||||
| @@ -80,9 +80,9 @@ FMK_FUNC_HOST_VISIBILITY Status DataOpParser::ParseShape(const vector<int64_t> & | |||||
| "Init ND Output Tensor failed"); | "Init ND Output Tensor failed"); | ||||
| } | } | ||||
| } | } | ||||
| i_tensor_desc.SetFormat(ge::TypeUtils::DomiFormatToFormat(GetParserContext().format)); | |||||
| i_tensor_desc.SetOriginFormat(ge::TypeUtils::DomiFormatToFormat(GetParserContext().format)); | |||||
| o_tensor_desc.SetFormat(ge::TypeUtils::DomiFormatToFormat(GetParserContext().format)); | |||||
| i_tensor_desc.SetFormat(ge::TypeUtils::DomiFormatToFormat(domi::GetContext().format)); | |||||
| i_tensor_desc.SetOriginFormat(ge::TypeUtils::DomiFormatToFormat(domi::GetContext().format)); | |||||
| o_tensor_desc.SetFormat(ge::TypeUtils::DomiFormatToFormat(domi::GetContext().format)); | |||||
| if (op->AddInputDesc(i_tensor_desc) != ge::GRAPH_SUCCESS) { | if (op->AddInputDesc(i_tensor_desc) != ge::GRAPH_SUCCESS) { | ||||
| GELOGE(domi::INTERNAL_ERROR, "AddInputDesc failed for op %s.", op->GetName().c_str()); | GELOGE(domi::INTERNAL_ERROR, "AddInputDesc failed for op %s.", op->GetName().c_str()); | ||||
| return FAILED; | return FAILED; | ||||
| @@ -128,10 +128,10 @@ Status DataOpParser::InitNDTensor(const vector<int64_t> &shape, ge::DataType dat | |||||
| } | } | ||||
| uint32_t type_size = 0; | uint32_t type_size = 0; | ||||
| if (ge::TypeUtils::GetDataTypeLength(data_type, type_size)) { | if (ge::TypeUtils::GetDataTypeLength(data_type, type_size)) { | ||||
| PARSER_INT64_UINT32_MULCHECK(size, type_size); | |||||
| FMK_INT64_UINT32_MULCHECK(size, type_size); | |||||
| size *= type_size; | size *= type_size; | ||||
| } else { | } else { | ||||
| PARSER_INT64_UINT32_MULCHECK(size, static_cast<uint32_t>(sizeof(float))); | |||||
| FMK_INT64_UINT32_MULCHECK(size, static_cast<uint32_t>(sizeof(float))); | |||||
| size *= sizeof(float); | size *= sizeof(float); | ||||
| } | } | ||||
| ge::TensorUtils::SetSize(tensor_desc, size); | ge::TensorUtils::SetSize(tensor_desc, size); | ||||
| @@ -169,7 +169,7 @@ Status DataOpParser::InitInputTensor(const vector<int64_t> &shape, ge::GeTensorD | |||||
| if (input.GetShape().GetDim(0) != -1) { | if (input.GetShape().GetDim(0) != -1) { | ||||
| size = input.GetShape().GetShapeSize(); | size = input.GetShape().GetShapeSize(); | ||||
| } | } | ||||
| PARSER_INT64_UINT32_MULCHECK(size, static_cast<uint32_t>(sizeof(float))); | |||||
| FMK_INT64_UINT32_MULCHECK(size, static_cast<uint32_t>(sizeof(float))); | |||||
| ge::TensorUtils::SetSize(input, size * sizeof(float)); | ge::TensorUtils::SetSize(input, size * sizeof(float)); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -21,7 +21,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "common/debug/log.h" | #include "common/debug/log.h" | ||||
| #include "common/op/attr_value_util.h" | #include "common/op/attr_value_util.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "common/types.h" | |||||
| #include "omg/omg_inner_types.h" | #include "omg/omg_inner_types.h" | ||||
| #include "proto/om.pb.h" | #include "proto/om.pb.h" | ||||
| @@ -1,155 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <sys/stat.h> | |||||
| #include <fcntl.h> | |||||
| #include "parser/common/model_saver.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "common/util/error_manager/error_manager.h" | |||||
| #include "mmpa/mmpa_api.h" | |||||
| namespace { | |||||
| const int kFileOpSuccess = 0; | |||||
| } // namespace | |||||
| namespace ge { | |||||
| namespace parser { | |||||
| const uint32_t kInteval = 2; | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFile(const char *file_path, | |||||
| const Json &model) { | |||||
| Status ret = SUCCESS; | |||||
| if (file_path == nullptr || SUCCESS != CheckPath(file_path)) { | |||||
| GELOGE(FAILED, "Check output file failed."); | |||||
| return FAILED; | |||||
| } | |||||
| std::string model_str; | |||||
| try { | |||||
| model_str = model.dump(kInteval, ' ', false, Json::error_handler_t::ignore); | |||||
| } catch (std::exception &e) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19007", {"exception"}, {e.what()}); | |||||
| GELOGE(FAILED, "Failed to convert JSON to string, reason: %s.", e.what()); | |||||
| return FAILED; | |||||
| } catch (...) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19008"); | |||||
| GELOGE(FAILED, "Failed to convert JSON to string."); | |||||
| return FAILED; | |||||
| } | |||||
| char real_path[PATH_MAX] = {0}; | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(file_path) >= PATH_MAX, return FAILED, "file path is too long!"); | |||||
| if (realpath(file_path, real_path) == nullptr) { | |||||
| GELOGI("File %s does not exit, it will be created.", file_path); | |||||
| } | |||||
| // Open file | |||||
| mode_t mode = S_IRUSR | S_IWUSR; | |||||
| int32_t fd = mmOpen2(real_path, O_RDWR | O_CREAT | O_TRUNC, mode); | |||||
| if (fd == EN_ERROR || fd == EN_INVALID_PARAM) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file_path, strerror(errno)}); | |||||
| GELOGE(FAILED, "Open file[%s] failed. %s", file_path, strerror(errno)); | |||||
| return FAILED; | |||||
| } | |||||
| const char *model_char = model_str.c_str(); | |||||
| uint32_t len = static_cast<uint32_t>(model_str.length()); | |||||
| // Write data to file | |||||
| mmSsize_t mmpa_ret = mmWrite(fd, const_cast<void *>((const void *)model_char), len); | |||||
| if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E19004", {"file", "errmsg"}, {file_path, strerror(errno)}); | |||||
| // Need to both print the error info of mmWrite and mmClose, so return ret after mmClose | |||||
| GELOGE(FAILED, "Write to file failed. errno = %d, %s", mmpa_ret, strerror(errno)); | |||||
| ret = FAILED; | |||||
| } | |||||
| // Close file | |||||
| if (mmClose(fd) != EN_OK) { | |||||
| GELOGE(FAILED, "Close file failed."); | |||||
| ret = FAILED; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::CheckPath(const std::string &file_path) { | |||||
| // Determine file path length | |||||
| if (file_path.size() >= PATH_MAX) { | |||||
| GELOGE(FAILED, "Path is too long:%zu", file_path.size()); | |||||
| return FAILED; | |||||
| } | |||||
| // Find the last separator | |||||
| int path_split_pos = static_cast<int>(file_path.size() - 1); | |||||
| for (; path_split_pos >= 0; path_split_pos--) { | |||||
| if (file_path[path_split_pos] == '\\' || file_path[path_split_pos] == '/') { | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (path_split_pos == 0) { | |||||
| return SUCCESS; | |||||
| } | |||||
| // If there is a path before the file name, create the path | |||||
| if (path_split_pos != -1) { | |||||
| if (CreateDirectory(std::string(file_path).substr(0, static_cast<size_t>(path_split_pos))) != kFileOpSuccess) { | |||||
| GELOGE(FAILED, "CreateDirectory failed, file path:%s.", file_path.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int ModelSaver::CreateDirectory(const std::string &directory_path) { | |||||
| GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); | |||||
| auto dir_path_len = directory_path.length(); | |||||
| if (dir_path_len >= PATH_MAX) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E19002", {"filepath", "size"}, {directory_path, std::to_string(PATH_MAX)}); | |||||
| GELOGW("Path[%s] len is too long, it must be less than %d", directory_path.c_str(), PATH_MAX); | |||||
| return -1; | |||||
| } | |||||
| char tmp_dir_path[PATH_MAX] = {0}; | |||||
| for (size_t i = 0; i < dir_path_len; i++) { | |||||
| tmp_dir_path[i] = directory_path[i]; | |||||
| if ((tmp_dir_path[i] == '\\') || (tmp_dir_path[i] == '/')) { | |||||
| if (access(tmp_dir_path, F_OK) != 0) { | |||||
| int32_t ret = mmMkdir(tmp_dir_path, S_IRUSR | S_IWUSR | S_IXUSR); // 700 | |||||
| if (ret != 0) { | |||||
| if (errno != EEXIST) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); | |||||
| GELOGW("Can not create directory %s. Make sure the directory exists and writable.", | |||||
| directory_path.c_str()); | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| int32_t ret = mmMkdir(const_cast<char *>(directory_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR); // 700 | |||||
| if (ret != 0) { | |||||
| if (errno != EEXIST) { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); | |||||
| GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| return 0; | |||||
| } | |||||
| } // namespace parser | |||||
| } // namespace ge | |||||
| @@ -1,55 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_COMMON_FILE_SAVER_H_ | |||||
| #define PARSER_COMMON_FILE_SAVER_H_ | |||||
| #include <string> | |||||
| #include "ge/ge_api_error_codes.h" | |||||
| #include "register/register_types.h" | |||||
| #include "nlohmann/json.hpp" | |||||
| namespace ge { | |||||
| namespace parser { | |||||
| using Json = nlohmann::json; | |||||
| using std::string; | |||||
| class ModelSaver { | |||||
| public: | |||||
| /** | |||||
| * @ingroup domi_common | |||||
| * @brief Save JSON object to file | |||||
| * @param [in] file_path File output path | |||||
| * @param [in] model json object | |||||
| * @return Status result | |||||
| */ | |||||
| static Status SaveJsonToFile(const char *file_path, const Json &model); | |||||
| private: | |||||
| /// | |||||
| /// @ingroup domi_common | |||||
| /// @brief Check validity of the file path | |||||
| /// @return Status result | |||||
| /// | |||||
| static Status CheckPath(const string &file_path); | |||||
| static int CreateDirectory(const std::string &directory_path); | |||||
| }; | |||||
| } // namespace parser | |||||
| } // namespace ge | |||||
| #endif //PARSER_COMMON_FILE_SAVER_H_ | |||||
| @@ -18,40 +18,35 @@ COMMON_LOCAL_SRC_FILES := \ | |||||
| register_tbe.cc \ | register_tbe.cc \ | ||||
| parser_api.cc \ | parser_api.cc \ | ||||
| parser_inner_ctx.cc \ | parser_inner_ctx.cc \ | ||||
| acl_graph_parser_util.cc\ | |||||
| proto_file_parser.cc \ | proto_file_parser.cc \ | ||||
| acl_graph_parser_util.cc \ | |||||
| tbe_plugin_loader.cc \ | |||||
| model_saver.cc \ | |||||
| ../../graph/passes/pass_manager.cc \ | |||||
| ../../graph/common/omg_util.cc \ | |||||
| ../tensorflow/tensorflow_custom_parser_adapter.cc \ | ../tensorflow/tensorflow_custom_parser_adapter.cc \ | ||||
| ../tensorflow/tensorflow_fusion_custom_parser_adapter.cc \ | ../tensorflow/tensorflow_fusion_custom_parser_adapter.cc \ | ||||
| ../tensorflow/tensorflow_fusion_op_parser.cc \ | ../tensorflow/tensorflow_fusion_op_parser.cc \ | ||||
| ../tensorflow/tensorflow_util.cc \ | ../tensorflow/tensorflow_util.cc \ | ||||
| convert/pb2json.cc \ | |||||
| ../../common/convert/pb2json.cc \ | |||||
| op_def/ir_pb_converter.cc \ | op_def/ir_pb_converter.cc \ | ||||
| op_def/defs.cc \ | op_def/defs.cc \ | ||||
| op_def/op_schema.cc \ | op_def/op_schema.cc \ | ||||
| op_def/operator.cc \ | op_def/operator.cc \ | ||||
| op_map.cc \ | op_map.cc \ | ||||
| parser_types.cc \ | |||||
| pass_manager.cc \ | |||||
| parser_fp16_t.cc \ | |||||
| thread_pool.cc \ | |||||
| parser_utils.cc \ | |||||
| FMK_COMMON_SRC_FILES := \ | FMK_COMMON_SRC_FILES := \ | ||||
| # ../../common/fmk_error_codes.cc \ | |||||
| ../../common/auth/cipher.cc \ | |||||
| ../../common/context/ctx.cc \ | |||||
| ../../graph/passes/pass_manager.cc \ | |||||
| ../../graph/common/omg_util.cc \ | |||||
| ../../common/types.cc \ | ../../common/types.cc \ | ||||
| ../../common/auth/file_saver.cc \ | |||||
| ../../common/util.cc \ | ../../common/util.cc \ | ||||
| ../../common/model_saver.cc \ | ../../common/model_saver.cc \ | ||||
| ../../common/fmk_error_codes.cc \ | |||||
| ../../common/fp16_t.cc \ | ../../common/fp16_t.cc \ | ||||
| ../../common/thread_pool.cc \ | ../../common/thread_pool.cc \ | ||||
| ../../common/auth/file_saver.cc \ | |||||
| ../../common/auth/cipher.cc \ | |||||
| ../../common/context/ctx.cc \ | |||||
| LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) | LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) | ||||
| LOCAL_SRC_FILES += $(FMK_COMMON_SRC_FILES) | |||||
| #LOCAL_SRC_FILES += $(FMK_COMMON_SRC_FILES) | |||||
| LOCAL_C_INCLUDES := \ | LOCAL_C_INCLUDES := \ | ||||
| proto/om.proto \ | proto/om.proto \ | ||||
| @@ -73,10 +68,9 @@ LOCAL_C_INCLUDES := \ | |||||
| $(TOPDIR)inc/external/graph \ | $(TOPDIR)inc/external/graph \ | ||||
| $(TOPDIR)inc/framework \ | $(TOPDIR)inc/framework \ | ||||
| $(TOPDIR)inc/common/util \ | $(TOPDIR)inc/common/util \ | ||||
| $(TOPDIR)graphengine/ge \ | |||||
| $(TOPDIR)graphengine/ge/common \ | |||||
| $(TOPDIR)parser/parser \ | |||||
| $(TOPDIR)parser \ | |||||
| $(TOPDIR)framework/domi \ | |||||
| $(TOPDIR)framework/domi/common \ | |||||
| $(TOPDIR)framework/domi/parser \ | |||||
| $(TOPDIR)third_party/json/include \ | $(TOPDIR)third_party/json/include \ | ||||
| $(TOPDIR)third_party/protobuf/include \ | $(TOPDIR)third_party/protobuf/include \ | ||||
| libc_sec/include \ | libc_sec/include \ | ||||
| @@ -90,6 +84,7 @@ LOCAL_SHARED_LIBRARIES := \ | |||||
| libc_sec \ | libc_sec \ | ||||
| liberror_manager \ | liberror_manager \ | ||||
| libregister \ | libregister \ | ||||
| libge_common \ | |||||
| LOCAL_LDFLAGS := -lrt -ldl | LOCAL_LDFLAGS := -lrt -ldl | ||||
| @@ -18,7 +18,7 @@ | |||||
| #ifndef DOMI_OP_CONSTANT_OP_H_ | #ifndef DOMI_OP_CONSTANT_OP_H_ | ||||
| #define DOMI_OP_CONSTANT_OP_H_ | #define DOMI_OP_CONSTANT_OP_H_ | ||||
| #include "parser/common/op_def/operator.h" | #include "parser/common/op_def/operator.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "framework/common/types.h" | |||||
| namespace ge { | namespace ge { | ||||
| class ConstantOperator : public ParserOperator { | class ConstantOperator : public ParserOperator { | ||||
| @@ -23,7 +23,7 @@ | |||||
| #include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
| #include "graph/buffer.h" | #include "graph/buffer.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "framework/common/types.h" | |||||
| #include "framework/common/util.h" | #include "framework/common/util.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -98,7 +98,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status ConvertToOpDesc(co | |||||
| GE_CHK_BOOL_RET_STATUS(op.GetSchema(), domi::PARAM_INVALID, "Op schema is null, op type: %s", op.GetType().c_str()); | GE_CHK_BOOL_RET_STATUS(op.GetSchema(), domi::PARAM_INVALID, "Op schema is null, op type: %s", op.GetType().c_str()); | ||||
| op_def->SetName(op.GetName()); | op_def->SetName(op.GetName()); | ||||
| op_def->SetType(op.GetType()); | op_def->SetType(op.GetType()); | ||||
| GE_IF_BOOL_EXEC(op.GetType() == ge::parser::YOLO, op_def->SetType(ge::parser::REGION)); | |||||
| GE_IF_BOOL_EXEC(op.GetType() == ge::YOLO, op_def->SetType(ge::REGION)); | |||||
| UpdateTensorForOpDesc(op, op_def); | UpdateTensorForOpDesc(op, op_def); | ||||
| GELOGD("Convert to op desc: name:%s, input size: %zu, output size:%zu", op_def->GetName().c_str(), | GELOGD("Convert to op desc: name:%s, input size: %zu, output size:%zu", op_def->GetName().c_str(), | ||||
| @@ -18,7 +18,7 @@ | |||||
| #ifndef DOMI_OP_NO_OP_OP_H_ | #ifndef DOMI_OP_NO_OP_OP_H_ | ||||
| #define DOMI_OP_NO_OP_OP_H_ | #define DOMI_OP_NO_OP_OP_H_ | ||||
| #include "parser/common/op_def/operator.h" | #include "parser/common/op_def/operator.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "framework/common/types.h" | |||||
| namespace ge { | namespace ge { | ||||
| class NoOpOperator : public ParserOperator { | class NoOpOperator : public ParserOperator { | ||||
| @@ -18,7 +18,7 @@ | |||||
| #ifndef DOMI_OP_REF_SWITCH_H_ | #ifndef DOMI_OP_REF_SWITCH_H_ | ||||
| #define DOMI_OP_REF_SWITCH_H_ | #define DOMI_OP_REF_SWITCH_H_ | ||||
| #include "parser/common/op_def/operator.h" | #include "parser/common/op_def/operator.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "framework/common/types.h" | |||||
| namespace ge { | namespace ge { | ||||
| class RefSwitchOperator : public ParserOperator { | class RefSwitchOperator : public ParserOperator { | ||||
| @@ -17,7 +17,7 @@ | |||||
| // AUTO GEN PLEASE DO NOT MODIFY IT | // AUTO GEN PLEASE DO NOT MODIFY IT | ||||
| #include "common/op_def/shape_n_op.h" | #include "common/op_def/shape_n_op.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "framework/common/types.h" | |||||
| namespace ge { | namespace ge { | ||||
| FMK_FUNC_HOST_VISIBILITY ShapeNOperator::ShapeNOperator() : ParserOperator("ShapeN") {} | FMK_FUNC_HOST_VISIBILITY ShapeNOperator::ShapeNOperator() : ParserOperator("ShapeN") {} | ||||
| @@ -18,7 +18,7 @@ | |||||
| #ifndef DOMI_OP_SHAPE_N_OP_H_ | #ifndef DOMI_OP_SHAPE_N_OP_H_ | ||||
| #define DOMI_OP_SHAPE_N_OP_H_ | #define DOMI_OP_SHAPE_N_OP_H_ | ||||
| #include "parser/common/op_def/operator.h" | #include "parser/common/op_def/operator.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "framework/common/types.h" | |||||
| namespace ge { | namespace ge { | ||||
| class ShapeNOperator : public ParserOperator { | class ShapeNOperator : public ParserOperator { | ||||
| @@ -20,7 +20,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| namespace ge { | namespace ge { | ||||
| VarIsInitializedOpOperator::VarIsInitializedOpOperator() : ParserOperator(ge::parser::VARISINITIALIZEDOP) {} | |||||
| VarIsInitializedOpOperator::VarIsInitializedOpOperator() : ParserOperator(ge::VARISINITIALIZEDOP) {} | |||||
| VarIsInitializedOpOperator::~VarIsInitializedOpOperator() {} | VarIsInitializedOpOperator::~VarIsInitializedOpOperator() {} | ||||
| @@ -18,7 +18,7 @@ | |||||
| #ifndef DOMI_OP_VARISINITIALIZEDOP_H_ | #ifndef DOMI_OP_VARISINITIALIZEDOP_H_ | ||||
| #define DOMI_OP_VARISINITIALIZEDOP_H_ | #define DOMI_OP_VARISINITIALIZEDOP_H_ | ||||
| #include "parser/common/op_def/operator.h" | #include "parser/common/op_def/operator.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "framework/common/types.h" | |||||
| namespace ge { | namespace ge { | ||||
| class VarIsInitializedOpOperator : public ParserOperator { | class VarIsInitializedOpOperator : public ParserOperator { | ||||
| @@ -19,7 +19,7 @@ | |||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| namespace ge { | namespace ge { | ||||
| VariableOperator::VariableOperator() : ParserOperator(ge::parser::VARIABLE) {} | |||||
| VariableOperator::VariableOperator() : ParserOperator(ge::VARIABLE) {} | |||||
| VariableOperator::~VariableOperator() {} | VariableOperator::~VariableOperator() {} | ||||
| @@ -19,7 +19,7 @@ | |||||
| #define DOMI_OP_VARIABLE_H_ | #define DOMI_OP_VARIABLE_H_ | ||||
| #include <vector> | #include <vector> | ||||
| #include "parser/common/op_def/operator.h" | #include "parser/common/op_def/operator.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "framework/common/types.h" | |||||
| namespace ge { | namespace ge { | ||||
| class VariableOperator : public ParserOperator { | class VariableOperator : public ParserOperator { | ||||
| @@ -20,13 +20,12 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "framework/common/types.h" | |||||
| #include "register/op_registry.h" | #include "register/op_registry.h" | ||||
| using std::map; | using std::map; | ||||
| using std::string; | using std::string; | ||||
| using std::vector; | using std::vector; | ||||
| using namespace ge::parser; | |||||
| namespace ge { | namespace ge { | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::map<std::string, std::string> caffe_op_map = { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::map<std::string, std::string> caffe_op_map = { | ||||
| @@ -98,7 +97,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY map<string, string> tensorflow_ | |||||
| {"VarHandleOp", VARHANDLEOP}, | {"VarHandleOp", VARHANDLEOP}, | ||||
| {"VarIsInitializedOp", VARISINITIALIZEDOP}, | {"VarIsInitializedOp", VARISINITIALIZEDOP}, | ||||
| {"IsVariableInitialized", ISVARIABLEINITIALIZED}, | {"IsVariableInitialized", ISVARIABLEINITIALIZED}, | ||||
| {"ReadVariableOp", READVARIABLEOP}, | |||||
| {"Reshape", RESHAPE}, | {"Reshape", RESHAPE}, | ||||
| {"Squeeze", SQUEEZE}, | {"Squeeze", SQUEEZE}, | ||||
| {"NoOp", NOOP}, | {"NoOp", NOOP}, | ||||
| @@ -23,8 +23,8 @@ | |||||
| #include <mutex> | #include <mutex> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "common/types.h" | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "omg/omg_inner_types.h" | #include "omg/omg_inner_types.h" | ||||
| #include "external/register/register.h" | #include "external/register/register.h" | ||||
| @@ -162,7 +162,7 @@ class CustomParserAdapterRegistrar { | |||||
| */ | */ | ||||
| #define REGISTER_OP_PARSER_CREATOR(framework, op_type, clazz) \ | #define REGISTER_OP_PARSER_CREATOR(framework, op_type, clazz) \ | ||||
| std::shared_ptr<OpParser> Creator_##framework##_##op_type##_Op_Parser() { \ | std::shared_ptr<OpParser> Creator_##framework##_##op_type##_Op_Parser() { \ | ||||
| std::shared_ptr<clazz> ptr = ge::parser::MakeShared<clazz>(); \ | |||||
| std::shared_ptr<clazz> ptr = ge::MakeShared<clazz>(); \ | |||||
| if (ptr == nullptr) { \ | if (ptr == nullptr) { \ | ||||
| GELOGW("MakeShared failed, result is nullptr."); \ | GELOGW("MakeShared failed, result is nullptr."); \ | ||||
| } \ | } \ | ||||
| @@ -173,7 +173,7 @@ class CustomParserAdapterRegistrar { | |||||
| #define REGISTER_FUSION_OP_PARSER_CREATOR(framework, op_type, clazz) \ | #define REGISTER_FUSION_OP_PARSER_CREATOR(framework, op_type, clazz) \ | ||||
| std::shared_ptr<OpParser> Creator_##framework##_##op_type##_Fusion_Op_Parser() { \ | std::shared_ptr<OpParser> Creator_##framework##_##op_type##_Fusion_Op_Parser() { \ | ||||
| std::shared_ptr<clazz> ptr = ge::parser::MakeShared<clazz>(); \ | |||||
| std::shared_ptr<clazz> ptr = ge::MakeShared<clazz>(); \ | |||||
| if (ptr == nullptr) { \ | if (ptr == nullptr) { \ | ||||
| GELOGW("MakeShared failed, result is nullptr."); \ | GELOGW("MakeShared failed, result is nullptr."); \ | ||||
| } \ | } \ | ||||
| @@ -187,7 +187,7 @@ class CustomParserAdapterRegistrar { | |||||
| /// @param [in] clazz CaffeCustomParserAdapter adaptation class | /// @param [in] clazz CaffeCustomParserAdapter adaptation class | ||||
| #define REGISTER_CUSTOM_PARSER_ADAPTER_CREATOR(framework, clazz) \ | #define REGISTER_CUSTOM_PARSER_ADAPTER_CREATOR(framework, clazz) \ | ||||
| std::shared_ptr<OpParser> Creator_##framework##_Op_Parser_Adapter() { \ | std::shared_ptr<OpParser> Creator_##framework##_Op_Parser_Adapter() { \ | ||||
| std::shared_ptr<clazz> ptr = ge::parser::MakeShared<clazz>(); \ | |||||
| std::shared_ptr<clazz> ptr = ge::MakeShared<clazz>(); \ | |||||
| if (ptr == nullptr) { \ | if (ptr == nullptr) { \ | ||||
| GELOGW("MakeShared failed, result is nullptr."); \ | GELOGW("MakeShared failed, result is nullptr."); \ | ||||
| } \ | } \ | ||||
| @@ -17,7 +17,7 @@ | |||||
| #include "framework/omg/parser/parser_api.h" | #include "framework/omg/parser/parser_api.h" | ||||
| #include "common/debug/log.h" | #include "common/debug/log.h" | ||||
| #include "tbe_plugin_loader.h" | |||||
| #include "common/ge/tbe_plugin_manager.h" | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "parser/common/register_tbe.h" | #include "parser/common/register_tbe.h" | ||||
| #include "framework/omg/parser/parser_inner_ctx.h" | #include "framework/omg/parser/parser_inner_ctx.h" | ||||
| @@ -36,7 +36,7 @@ Status ParserInitialize(const std::map<std::string, std::string> &options) { | |||||
| } | } | ||||
| // load custom op plugin | // load custom op plugin | ||||
| TBEPluginLoader::Instance().LoadPluginSo(options); | |||||
| TBEPluginManager::Instance().LoadPluginSo(options); | |||||
| std::vector<OpRegistrationData> registrationDatas = domi::OpRegistry::Instance()->registrationDatas; | std::vector<OpRegistrationData> registrationDatas = domi::OpRegistry::Instance()->registrationDatas; | ||||
| GELOGI("The size of registrationDatas in parser is: %zu", registrationDatas.size()); | GELOGI("The size of registrationDatas in parser is: %zu", registrationDatas.size()); | ||||
| @@ -67,7 +67,7 @@ Status ParserFinalize() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| GE_CHK_STATUS(TBEPluginLoader::Instance().Finalize()); | |||||
| GE_CHK_STATUS(TBEPluginManager::Instance().Finalize()); | |||||
| if (parser_initialized) { | if (parser_initialized) { | ||||
| parser_initialized = false; | parser_initialized = false; | ||||
| } | } | ||||
| @@ -1,653 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_COMMON_FP16_T_H_ | |||||
| #define PARSER_COMMON_FP16_T_H_ | |||||
| #include <algorithm> | |||||
| #include <cmath> | |||||
| #include <cstdint> | |||||
| namespace ge { | |||||
| namespace parser { | |||||
| using DimIndex = enum { | |||||
| kDim0 = 0, | |||||
| kDim1, | |||||
| kDim2, | |||||
| kDim3, | |||||
| kDim4, | |||||
| kDim5, | |||||
| kDim6, | |||||
| kDim7, | |||||
| kDim8, | |||||
| kDim9, | |||||
| kDim10, | |||||
| kDim11, | |||||
| kDim12, | |||||
| kDim13, | |||||
| kDim14, | |||||
| kDim15, | |||||
| kDim16, | |||||
| }; | |||||
| using BitShift = enum { | |||||
| kBitShift2 = 2, | |||||
| kBitShift3 = 3, | |||||
| kBitShift4 = 4, | |||||
| kBitShift5 = 5, | |||||
| kBitShift6 = 6, | |||||
| kBitShift7 = 7, | |||||
| kBitShift8 = 8, | |||||
| kBitShift9 = 9, | |||||
| kBitShift10 = 10, | |||||
| kBitShift11 = 11, | |||||
| kBitShift12 = 12, | |||||
| kBitShift13 = 13, | |||||
| kBitShift14 = 14, | |||||
| kBitShift15 = 15, | |||||
| kBitShift16 = 16, | |||||
| kBitShift20 = 20, | |||||
| kBitShift24 = 24, | |||||
| kBitShift27 = 27, | |||||
| kBitShift28 = 28, | |||||
| kBitShift31 = 31, | |||||
| kBitShift32 = 32, | |||||
| kBitShift36 = 36, | |||||
| kBitShift40 = 40, | |||||
| kBitShift44 = 44, | |||||
| kBitShift48 = 48, | |||||
| kBitShift52 = 52, | |||||
| kBitShift56 = 56, | |||||
| kBitShift59 = 59, | |||||
| kBitShift60 = 60, | |||||
| kBitShift63 = 63, | |||||
| kBitShift64 = 64, | |||||
| kBitShift128 = 128, | |||||
| kBitShift255 = 255, | |||||
| kBitShift256 = 256, | |||||
| kBitShift512 = 512, | |||||
| kBitShift768 = 768, | |||||
| kBitShift784 = 784, | |||||
| kBitShift1020 = 1020, | |||||
| kBitShift1024 = 1024, | |||||
| kBitShift3136 = 3136, | |||||
| kBitShift4096 = 4096, | |||||
| kBitShift6144 = 6144, | |||||
| kBitShift10240 = 10240, | |||||
| kBitShift65536 = 65536 | |||||
| }; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief fp16 exponent bias | |||||
| constexpr uint16_t kFp16ExpBias = 15; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief the exponent bit length of fp16 is 5 | |||||
| constexpr uint16_t kFp16ExpLen = 5; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief the mantissa bit length of fp16 is 10 | |||||
| constexpr uint16_t kFp16ManLen = 10; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief bit index of sign in fp16 | |||||
| constexpr uint16_t kFp16SignIndex = 15; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief sign mask of fp16 (1 00000 00000 00000) | |||||
| constexpr uint16_t kFp16SignMask = 0x8000; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief exponent mask of fp16 ( 11111 00000 00000) | |||||
| constexpr uint16_t kFp16ExpMask = 0x7C00; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief mantissa mask of fp16 ( 11111 11111) | |||||
| constexpr uint16_t kFp16ManMask = 0x03FF; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief hide bit of mantissa of fp16( 1 00000 00000) | |||||
| constexpr uint16_t kFp16ManHideBit = 0x0400; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief maximum value (0111 1011 1111 1111) | |||||
| constexpr uint16_t kFp16Max = 0x7BFF; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief minimum value (1111 1011 1111 1111) | |||||
| constexpr uint16_t kFp16Min = 0xFBFF; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief absolute maximum value (0111 1111 1111 1111) | |||||
| constexpr uint16_t kFp16AbsMax = 0x7FFF; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief maximum exponent value of fp16 is 15(11111) | |||||
| constexpr uint16_t kFp16MaxExp = 0x001F; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief maximum valid exponent value of fp16 is 14(11110) | |||||
| constexpr uint16_t kFp16MaxValidExp = 0x001E; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief maximum mantissa value of fp16(11111 11111) | |||||
| constexpr uint16_t kFp16MaxMan = 0x03FF; | |||||
| /// @ingroup fp16 basic parameter | |||||
| /// @brief absolute minimum normal value of fp16 | |||||
| /// (E=1,M=0 D=2^(-14)=0.00006103515625) | |||||
| constexpr uint16_t kFp16MinNormal = 1.0f / (2 << 14); | |||||
| /// @ingroup fp16 basic operator | |||||
| /// @brief get sign of fp16 | |||||
| #define FP16_EXTRAC_SIGN(x) (((x) >> 15) & 1) | |||||
| /// @ingroup fp16 basic operator | |||||
| /// @brief get exponent of fp16 | |||||
| #define FP16_EXTRAC_EXP(x) (((x) >> 10) & kFp16MaxExp) | |||||
| /// @ingroup fp16 basic operator | |||||
| /// @brief get mantissa of fp16 | |||||
| #define FP16_EXTRAC_MAN(x) ((((x) >> 0) & 0x3FF) | (((((x) >> 10) & 0x1F) > 0 ? 1 : 0) * 0x400)) | |||||
| /// @ingroup fp16 basic operator | |||||
| /// @brief constructor of fp16 from sign exponent and mantissa | |||||
| #define FP16_CONSTRUCTOR(s, e, m) (((s) << kFp16SignIndex) | ((e) << kFp16ManLen) | ((m)&kFp16MaxMan)) | |||||
| /// @ingroup fp16 special value judgment | |||||
| /// @brief whether a fp16 is zero | |||||
| #define FP16_IS_ZERO(x) (((x)&kFp16AbsMax) == 0) | |||||
| /// @ingroup fp16 special value judgment | |||||
| /// @brief whether a fp16 is a denormalized value | |||||
| #define FP16_IS_DENORM(x) ((((x)&kFp16ExpMask) == 0)) | |||||
| /// @ingroup fp16 special value judgment | |||||
| /// @brief whether a fp16 is infinite | |||||
| #define FP16_IS_INF(x) (((x)&kFp16AbsMax) == kFp16ExpMask) | |||||
| /// @ingroup fp16 special value judgment | |||||
| /// @brief whether a fp16 is NaN | |||||
| #define FP16_IS_NAN(x) (((x & kFp16ExpMask) == kFp16ExpMask) && (x & kFp16ManMask)) | |||||
| /// @ingroup fp16 special value judgment | |||||
| /// @brief whether a fp16 is invalid | |||||
| #define FP16_IS_INVALID(x) ((x & kFp16ExpMask) == kFp16ExpMask) | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief fp32 exponent bias | |||||
| constexpr uint16_t kFp32ExpBias = 127; | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief the exponent bit length of float/fp32 is 8 | |||||
| constexpr uint16_t kFp32ExpLen = 8; | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief the mantissa bit length of float/fp32 is 23 | |||||
| constexpr uint16_t kFp32ManLen = 23; | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief bit index of sign in float/fp32 | |||||
| constexpr uint16_t kFp32SignIndex = 31; | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief sign mask of fp32 (1 0000 0000 0000 0000 0000 0000 000) | |||||
| constexpr uint32_t kFp32SignMask = 0x80000000u; | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief exponent mask of fp32 ( 1111 1111 0000 0000 0000 0000 000) | |||||
| constexpr uint32_t kFp32ExpMask = 0x7F800000u; | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief mantissa mask of fp32 ( 1111 1111 1111 1111 111) | |||||
| constexpr uint32_t kFp32ManMask = 0x007FFFFFu; | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief hide bit of mantissa of fp32 ( 1 0000 0000 0000 0000 000) | |||||
| constexpr uint32_t kFp32ManHideBit = 0x00800000u; | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief absolute maximum value (0 1111 1111 1111 1111 1111 1111 111) | |||||
| constexpr uint32_t kFp32AbsMax = 0x7FFFFFFFu; | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief maximum exponent value of fp32 is 255(1111 1111) | |||||
| constexpr uint32_t kFp32MaxExp = 0xFF; | |||||
| /// @ingroup fp32 basic parameter | |||||
| /// @brief maximum mantissa value of fp32 (1111 1111 1111 1111 1111 111) | |||||
| constexpr uint32_t kFp32MaxMan = 0x7FFFFF; | |||||
| /// @ingroup fp32 special value judgment | |||||
| /// @brief whether a fp32 is NaN | |||||
| #define FP32_IS_NAN(x) (((x & kFp32ExpMask) == kFp32ExpMask) && (x & kFp32ManMask)) | |||||
| /// @ingroup fp32 special value judgment | |||||
| /// @brief whether a fp32 is infinite | |||||
| #define FP32_IS_INF(x) (((x & kFp32ExpMask) == kFp32ExpMask) && (!(x & kFp32ManMask))) | |||||
| /// @ingroup fp32 special value judgment | |||||
| /// @brief whether a fp32 is a denormalized value | |||||
| #define FP32_IS_DENORM(x) ((((x)&kFp32ExpMask) == 0)) | |||||
| /// @ingroup fp32 basic operator | |||||
| /// @brief get sign of fp32 | |||||
| #define FP32_EXTRAC_SIGN(x) (((x) >> kFp32SignIndex) & 1) | |||||
| /// @ingroup fp32 basic operator | |||||
| /// @brief get exponent of fp16 | |||||
| #define FP32_EXTRAC_EXP(x) (((x)&kFp32ExpMask) >> kFp32ManLen) | |||||
| /// @ingroup fp32 basic operator | |||||
| /// @brief get mantissa of fp16 | |||||
| #define FP32_EXTRAC_MAN(x) (((x)&kFp32ManMask) | (((((x) >> kFp32ManLen) & kFp32MaxExp) > 0 ? 1 : 0) * kFp32ManHideBit)) | |||||
| /// @ingroup fp32 basic operator | |||||
| /// @brief constructor of fp32 from sign exponent and mantissa | |||||
| #define FP32_CONSTRUCTOR(s, e, m) (((s) << kFp32SignIndex) | ((e) << kFp32ManLen) | ((m)&kFp32MaxMan)) | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief fp64 exponent bias | |||||
| constexpr uint16_t kFp64ExpBias = 1023; | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief the exponent bit length of double/fp64 is 11 | |||||
| constexpr uint16_t kFp64ExpLen = 11; | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief the mantissa bit length of double/fp64 is 52 | |||||
| constexpr uint16_t kFp64ManLen = 52; | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief bit index of sign in double/fp64 is 63 | |||||
| constexpr uint16_t kFp64SignIndex = 63; | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief sign mask of fp64 (1 000 (total 63bits 0)) | |||||
| constexpr uint64_t kFp64SignMask = 0x8000000000000000LLu; | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief exponent mask of fp64 (0 1 11111 11111 0000?-?-(total 52bits 0)) | |||||
| constexpr uint64_t kFp64ExpMask = 0x7FF0000000000000LLu; | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief mantissa mask of fp64 ( 1111?-?-(total 52bits 1)) | |||||
| constexpr uint64_t kFp64ManMask = 0x000FFFFFFFFFFFFFLLu; | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief hide bit of mantissa of fp64 ( 1 0000?-?-(total 52bits 0)) | |||||
| constexpr uint64_t kFp64ManHideBit = 0x0010000000000000LLu; | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief absolute maximum value (0 111?-?-(total 63bits 1)) | |||||
| constexpr uint64_t kFp64AbsMax = 0x7FFFFFFFFFFFFFFFLLu; | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief maximum exponent value of fp64 is 2047(1 11111 11111) | |||||
| constexpr uint64_t kFp64MaxExp = 0x07FF; | |||||
| /// @ingroup fp64 basic parameter | |||||
| /// @brief maximum mantissa value of fp64 (111?-?-(total 52bits 1)) | |||||
| constexpr uint64_t kFp64MaxMan = 0xFFFFFFFFFFFLLu; | |||||
| /// @ingroup fp64 special value judgment | |||||
| /// @brief whether a fp64 is NaN | |||||
| #define FP64_IS_NAN(x) (((x & kFp64ExpMask) == kFp64ExpMask) && (x & kFp64ManMask)) | |||||
| /// @ingroup fp64 special value judgment | |||||
| /// @brief whether a fp64 is infinite | |||||
| #define FP64_IS_INF(x) (((x & kFp64ExpMask) == kFp64ExpMask) && (!(x & kFp64ManMask))) | |||||
| /// @ingroup integer special value judgment | |||||
| /// @brief maximum positive value of int8_t (0111 1111) | |||||
| constexpr int8_t kInt8Max = 0x7F; | |||||
| /// @ingroup integer special value judgment | |||||
| /// @brief maximum value of a data with 8 bits length (1111 111) | |||||
| constexpr uint8_t kBitLen8Max = 0xFF; | |||||
| /// @ingroup integer special value judgment | |||||
| /// @brief maximum positive value of int16_t (0111 1111 1111 1111) | |||||
| constexpr int16_t kInt16Max = 0x7FFF; | |||||
| /// @ingroup integer special value judgment | |||||
| /// @brief maximum value of a data with 16 bits length (1111 1111 1111 1111) | |||||
| constexpr uint16_t kBitLen16Max = 0xFFFF; | |||||
| /// @ingroup integer special value judgment | |||||
| /// @brief maximum positive value of int32_t (0111 1111 1111 1111 1111 1111 1111 1111) | |||||
| constexpr int32_t kInt32Max = 0x7FFFFFFFu; | |||||
| /// @ingroup integer special value judgment | |||||
| /// @brief maximum value of a data with 32 bits length (1111 1111 1111 1111 1111 1111 1111 1111) | |||||
| constexpr uint32_t kBitLen32Max = 0xFFFFFFFFu; | |||||
| /// @ingroup integer special value judgment | |||||
| /// @brief maximum positive value of int64_t | |||||
| /// (0111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111) | |||||
| constexpr int64_t kInt64Max = 0x7FFFFFFFFFFFFFFFu; | |||||
| /// @ingroup integer special value judgment | |||||
| /// @brief maximum value of a data with 64 bits length | |||||
| /// (1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111 1111) | |||||
| constexpr uint64_t kBitLen64Max = 0xFFFFFFFFFFFFFFFFu; | |||||
| /// @ingroup fp16_t enum | |||||
| /// @brief round mode of last valid digital | |||||
| enum TagFp16RoundMode { | |||||
| kRoundToNearest = 0, // < round to nearest even | |||||
| kRoundByTruncated, // < round by truncated | |||||
| kRoundModeReserved, | |||||
| }; | |||||
| /// @ingroup fp16_t | |||||
| /// @brief Half precision float | |||||
| /// bit15: 1 bit SIGN +---+-----+------------+ | |||||
| /// bit14-10: 5 bit EXP | S |EEEEE|MM MMMM MMMM| | |||||
| /// bit0-9: 10bit MAN +---+-----+------------+ | |||||
| using fp16_t = struct TagFp16 { | |||||
| uint16_t val; | |||||
| public: | |||||
| /// @ingroup fp16_t constructor | |||||
| /// @brief Constructor without any param(default constructor) | |||||
| TagFp16(void) { val = 0x0u; } | |||||
| /// @ingroup fp16_t constructor | |||||
| /// @brief Constructor with an uint16_t value | |||||
| TagFp16(const uint16_t &ui_val) : val(ui_val) {} | |||||
| /// @ingroup fp16_t constructor | |||||
| /// @brief Constructor with a fp16_t object(copy constructor) | |||||
| TagFp16(const TagFp16 &fp) : val(fp.val) {} | |||||
| /// @ingroup fp16_t math operator | |||||
| /// @param [in] fp fp16_t object to be added | |||||
| /// @brief Override addition operator to performing fp16_t addition | |||||
| /// @return Return fp16_t result of adding this and fp | |||||
| TagFp16 operator+(const TagFp16 fp); | |||||
| /// @ingroup fp16_t math operator | |||||
| /// @param [in] fp fp16_t object to be subtracted | |||||
| /// @brief Override addition operator to performing fp16_t subtraction | |||||
| /// @return Return fp16_t result of subtraction fp from this | |||||
| TagFp16 operator-(const TagFp16 fp); | |||||
| /// @ingroup fp16_t math operator | |||||
| /// @param [in] fp fp16_t object to be multiplied | |||||
| /// @brief Override multiplication operator to performing fp16_t multiplication | |||||
| /// @return Return fp16_t result of multiplying this and fp | |||||
| TagFp16 operator*(const TagFp16 fp); | |||||
| /// @ingroup fp16_t math operator divided | |||||
| /// @param [in] fp fp16_t object to be divided | |||||
| /// @brief Override division operator to performing fp16_t division | |||||
| /// @return Return fp16_t result of division this by fp | |||||
| TagFp16 operator/(const TagFp16 fp); | |||||
| /// @ingroup fp16_t math operator | |||||
| /// @param [in] fp fp16_t object to be added | |||||
| /// @brief Override addition operator to performing fp16_t addition | |||||
| /// @return Return fp16_t result of adding this and fp | |||||
| TagFp16 operator+=(const TagFp16 fp); | |||||
| /// @ingroup fp16_t math operator | |||||
| /// @param [in] fp fp16_t object to be subtracted | |||||
| /// @brief Override addition operator to performing fp16_t subtraction | |||||
| /// @return Return fp16_t result of subtraction fp from this | |||||
| TagFp16 operator-=(const TagFp16 fp); | |||||
| /// @ingroup fp16_t math operator | |||||
| /// @param [in] fp fp16_t object to be multiplied | |||||
| /// @brief Override multiplication operator to performing fp16_t multiplication | |||||
| /// @return Return fp16_t result of multiplying this and fp | |||||
| TagFp16 operator*=(const TagFp16 fp); | |||||
| /// @ingroup fp16_t math operator divided | |||||
| /// @param [in] fp fp16_t object to be divided | |||||
| /// @brief Override division operator to performing fp16_t division | |||||
| /// @return Return fp16_t result of division this by fp | |||||
| TagFp16 operator/=(const TagFp16 fp); | |||||
| /// @ingroup fp16_t math compare operator | |||||
| /// @param [in] fp fp16_t object to be compared | |||||
| /// @brief Override basic comparison operator to performing fp16_t if-equal comparison | |||||
| /// @return Return boolean result of if-equal comparison of this and fp. | |||||
| bool operator==(const TagFp16 &fp) const; | |||||
| /// @ingroup fp16_t math compare operator | |||||
| /// @param [in] fp fp16_t object to be compared | |||||
| /// @brief Override basic comparison operator to performing fp16_t not-equal comparison | |||||
| /// @return Return boolean result of not-equal comparison of this and fp. | |||||
| bool operator!=(const TagFp16 &fp) const; | |||||
| /// @ingroup fp16_t math compare operator | |||||
| /// @param [in] fp fp16_t object to be compared | |||||
| /// @brief Override basic comparison operator to performing fp16_t greater-than comparison | |||||
| /// @return Return boolean result of greater-than comparison of this and fp. | |||||
| bool operator>(const TagFp16 &fp) const; | |||||
| /// @ingroup fp16_t math compare operator | |||||
| /// @param [in] fp fp16_t object to be compared | |||||
| /// @brief Override basic comparison operator to performing fp16_t greater-equal comparison | |||||
| /// @return Return boolean result of greater-equal comparison of this and fp. | |||||
| bool operator>=(const TagFp16 &fp) const; | |||||
| /// @ingroup fp16_t math compare operator | |||||
| /// @param [in] fp fp16_t object to be compared | |||||
| /// @brief Override basic comparison operator to performing fp16_t less-than comparison | |||||
| /// @return Return boolean result of less-than comparison of this and fp. | |||||
| bool operator<(const TagFp16 &fp) const; | |||||
| /// @ingroup fp16_t math compare operator | |||||
| /// @param [in] fp fp16_t object to be compared | |||||
| /// @brief Override basic comparison operator to performing fp16_t less-equal comparison | |||||
| /// @return Return boolean result of less-equal comparison of this and fp. | |||||
| bool operator<=(const TagFp16 &fp) const; | |||||
| /// @ingroup fp16_t math evaluation operator | |||||
| /// @param [in] fp fp16_t object to be copy to fp16_t | |||||
| /// @brief Override basic evaluation operator to copy fp16_t to a new fp16_t | |||||
| /// @return Return fp16_t result from fp | |||||
| TagFp16 &operator=(const TagFp16 &fp); | |||||
| /// @ingroup fp16_t math evaluation operator | |||||
| /// @param [in] f_val float object to be converted to fp16_t | |||||
| /// @brief Override basic evaluation operator to convert float to fp16_t | |||||
| /// @return Return fp16_t result from f_val | |||||
| TagFp16 &operator=(const float &f_val); | |||||
| /// @ingroup fp16_t math evaluation operator | |||||
| /// @param [in] d_val double object to be converted to fp16_t | |||||
| /// @brief Override basic evaluation operator to convert double to fp16_t | |||||
| /// @return Return fp16_t result from d_val | |||||
| TagFp16 &operator=(const double &d_val); | |||||
| /// @ingroup fp16_t math evaluation operator | |||||
| /// @param [in] i_val float object to be converted to fp16_t | |||||
| /// @brief Override basic evaluation operator to convert float to fp16_t | |||||
| /// @return Return fp16_t result from i_val | |||||
| TagFp16 &operator=(const int8_t &i_val); | |||||
| /// @ingroup fp16_t math evaluation operator | |||||
| /// @param [in] ui_val uint8_t object to be converted to fp16_t | |||||
| /// @brief Override basic evaluation operator to convert uint8_t to fp16_t | |||||
| /// @return Return fp16_t result from ui_val | |||||
| TagFp16 &operator=(const uint8_t &ui_val); | |||||
| /// @ingroup fp16_t math evaluation operator | |||||
| /// @param [in] i_val int16_t object to be converted to fp16_t | |||||
| /// @brief Override basic evaluation operator to convert int16_t to fp16_t | |||||
| /// @return Return fp16_t result from i_val | |||||
| TagFp16 &operator=(const int16_t &i_val); | |||||
| /// @ingroup fp16_t math evaluation operator | |||||
| /// @param [in] ui_val uint16_t object to be converted to fp16_t | |||||
| /// @brief Override basic evaluation operator to convert uint16_t to fp16_t | |||||
| /// @return Return fp16_t result from ui_val | |||||
| TagFp16 &operator=(const uint16_t &ui_val); | |||||
| /// @ingroup fp16_t math evaluation operator | |||||
| /// @param [in] i_val int32_t object to be converted to fp16_t | |||||
| /// @brief Override basic evaluation operator to convert int32_t to fp16_t | |||||
| /// @return Return fp16_t result from i_val | |||||
| TagFp16 &operator=(const int32_t &i_val); | |||||
| /// @ingroup fp16_t math evaluation operator | |||||
| /// @param [in] ui_val uint32_t object to be converted to fp16_t | |||||
| /// @brief Override basic evaluation operator to convert uint32_t to fp16_t | |||||
| /// @return Return fp16_t result from ui_val | |||||
| TagFp16 &operator=(const uint32_t &ui_val); | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Override convert operator to convert fp16_t to float/fp32 | |||||
| /// @return Return float/fp32 value of fp16_t | |||||
| operator float() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Override convert operator to convert fp16_t to double/fp64 | |||||
| /// @return Return double/fp64 value of fp16_t | |||||
| operator double() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Override convert operator to convert fp16_t to int8_t | |||||
| /// @return Return int8_t value of fp16_t | |||||
| operator int8_t() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Override convert operator to convert fp16_t to uint8_t | |||||
| /// @return Return uint8_t value of fp16_t | |||||
| operator uint8_t() const; | |||||
| /// @ingroup fp16_t conversion | |||||
| /// @brief Override convert operator to convert fp16_t to int16_t | |||||
| /// @return Return int16_t value of fp16_t | |||||
| operator int16_t() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Override convert operator to convert fp16_t to uint16_t | |||||
| /// @return Return uint16_t value of fp16_t | |||||
| operator uint16_t() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Override convert operator to convert fp16_t to int32_t | |||||
| /// @return Return int32_t value of fp16_t | |||||
| operator int32_t() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Override convert operator to convert fp16_t to uint32_t | |||||
| /// @return Return uint32_t value of fp16_t | |||||
| operator uint32_t() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Override convert operator to convert fp16_t to int64_t | |||||
| /// @return Return int64_t value of fp16_t | |||||
| operator int64_t() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Override convert operator to convert fp16_t to uint64_t | |||||
| /// @return Return uint64_t value of fp16_t | |||||
| operator uint64_t() const; | |||||
| /// @ingroup fp16_t judgment method | |||||
| /// @param [in] fp fp16_t object to be judgement | |||||
| /// @brief whether a fp16_t is inifinite | |||||
| /// @return Returns 1:+INF -1:-INF 0:not INF | |||||
| int IsInf(); | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Convert fp16_t to float/fp32 | |||||
| /// @return Return float/fp32 value of fp16_t | |||||
| float ToFloat() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Convert fp16_t to double/fp64 | |||||
| /// @return Return double/fp64 value of fp16_t | |||||
| double ToDouble() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Convert fp16_t to int8_t | |||||
| /// @return Return int8_t value of fp16_t | |||||
| int8_t ToInt8() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Convert fp16_t to uint8_t | |||||
| /// @return Return uint8_t value of fp16_t | |||||
| uint8_t ToUInt8() const; | |||||
| /// @ingroup fp16_t conversion | |||||
| /// @brief Convert fp16_t to int16_t | |||||
| /// @return Return int16_t value of fp16_t | |||||
| int16_t ToInt16() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Convert fp16_t to uint16_t | |||||
| /// @return Return uint16_t value of fp16_t | |||||
| uint16_t ToUInt16() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Convert fp16_t to int32_t | |||||
| /// @return Return int32_t value of fp16_t | |||||
| int32_t ToInt32() const; | |||||
| /// @ingroup fp16_t math conversion | |||||
| /// @brief Convert fp16_t to uint32_t | |||||
| /// @return Return uint32_t value of fp16_t | |||||
| uint32_t ToUInt32() const; | |||||
| }; | |||||
| /// @ingroup fp16_t public method | |||||
| /// @param [in] val signature is negative | |||||
| /// @param [in|out] s sign of fp16_t object | |||||
| /// @param [in|out] e exponent of fp16_t object | |||||
| /// @param [in|out] m mantissa of fp16_t object | |||||
| /// @brief Extract the sign, exponent and mantissa of a fp16_t object | |||||
| void ExtractFp16(const uint16_t &val, uint16_t &s, int16_t &e, uint16_t &m); | |||||
| /// @ingroup fp16_t public method | |||||
| /// @param [in] negative sign is negative | |||||
| /// @param [in|out] man mantissa to be reverse | |||||
| /// @brief Calculate a mantissa's complement (add ont to it's radix-minus-one complement) | |||||
| /// @return Return complement of man | |||||
| template<typename T> | |||||
| void ReverseMan(bool negative, T &man) { | |||||
| if (negative) { | |||||
| man = (~(man)) + 1; | |||||
| } | |||||
| } | |||||
| /// @ingroup fp16_t public method | |||||
| /// @param [in] e_a exponent of one fp16_t/float number | |||||
| /// @param [in] m_a mantissa of one fp16_t/float number | |||||
| /// @param [in] e_b exponent of another fp16_t/float number | |||||
| /// @param [in] m_b mantissa of another fp16_t/float number | |||||
| /// @brief choose mantissa to be shift right whoes exponent is less than another one | |||||
| /// @return Return mantissawhoes exponent is less than another one | |||||
| template<typename T> | |||||
| T MinMan(const int16_t &e_a, T &m_a, const int16_t &e_b, T &m_b) { | |||||
| return (e_a > e_b) ? m_b : m_a; | |||||
| } | |||||
| /// @ingroup fp16_t public method | |||||
| /// @param [in] man mantissa to be operate | |||||
| /// @param [in] shift right shift bits | |||||
| /// @brief right shift a mantissa | |||||
| /// @return Return right-shift mantissa | |||||
| template<typename T> | |||||
| T RightShift(T man, int16_t shift) { | |||||
| int bits = sizeof(T) * 8; // one byte have 8 bits | |||||
| T mask = (((T) 1u) << ((unsigned int) (bits - 1))); | |||||
| for (int i = 0; i < shift; i++) { | |||||
| man = ((man & mask) | (man >> 1)); | |||||
| } | |||||
| return man; | |||||
| } | |||||
| /// @ingroup fp16_t public method | |||||
| /// @param [in] e_a exponent of one temp fp16_t number | |||||
| /// @param [in] m_a mantissa of one temp fp16_t number | |||||
| /// @param [in] e_b exponent of another temp fp16_t number | |||||
| /// @param [in] m_b mantissa of another temp fp16_t number | |||||
| /// @brief Get mantissa sum of two temp fp16_t numbers, T support types: uint16_t/uint32_t/uint64_t | |||||
| /// @return Return mantissa sum | |||||
| template<typename T> | |||||
| T GetManSum(int16_t e_a, const T &m_a, int16_t e_b, const T &m_b) { | |||||
| T sum = 0; | |||||
| if (e_a != e_b) { | |||||
| T m_tmp = 0; | |||||
| int16_t e_tmp = std::abs(e_a - e_b); | |||||
| if (e_a > e_b) { | |||||
| m_tmp = m_b; | |||||
| m_tmp = RightShift(m_tmp, e_tmp); | |||||
| sum = m_a + m_tmp; | |||||
| } else { | |||||
| m_tmp = m_a; | |||||
| m_tmp = RightShift(m_tmp, e_tmp); | |||||
| sum = m_tmp + m_b; | |||||
| } | |||||
| } else { | |||||
| sum = m_a + m_b; | |||||
| } | |||||
| return sum; | |||||
| } | |||||
| /// @ingroup fp16_t public method | |||||
| /// @param [in] bit0 whether the last preserved bit is 1 before round | |||||
| /// @param [in] bit1 whether the abbreviation's highest bit is 1 | |||||
| /// @param [in] bitLeft whether the abbreviation's bits which not contain highest bit grater than 0 | |||||
| /// @param [in] man mantissa of a fp16_t or float number, support types: uint16_t/uint32_t/uint64_t | |||||
| /// @param [in] shift abbreviation bits | |||||
| /// @brief Round fp16_t or float mantissa to nearest value | |||||
| /// @return Returns true if round 1,otherwise false; | |||||
| template<typename T> | |||||
| T ManRoundToNearest(bool bit0, bool bit1, bool bitLeft, T man, uint16_t shift = 0) { | |||||
| man = (man >> shift) + ((bit1 && (bitLeft || bit0)) ? 1 : 0); | |||||
| return man; | |||||
| } | |||||
| /// @ingroup fp16_t public method | |||||
| /// @param [in] man mantissa of a float number, support types: uint16_t/uint32_t/uint64_t | |||||
| /// @brief Get bit length of a uint32_t number | |||||
| /// @return Return bit length of man | |||||
| template<typename T> | |||||
| int16_t GetManBitLength(T man) { | |||||
| int16_t len = 0; | |||||
| while (man) { | |||||
| man >>= 1; | |||||
| len++; | |||||
| } | |||||
| return len; | |||||
| } | |||||
| } // namespace parser | |||||
| } // namespace ge | |||||
| #endif // GE_PARSER_COMMON_FP16_T_H_ | |||||
| @@ -1,494 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| namespace ge{ | |||||
| namespace parser { | |||||
| const char *DATA = "Data"; | |||||
| const char *AIPPDATA = "AippData"; | |||||
| const char *CONVOLUTION = "Convolution"; | |||||
| const char *CORRELATION = "Correlation"; | |||||
| const char *CORRELATIONV2 = "Correlation_V2"; | |||||
| const char *DECONVOLUTION = "Deconvolution"; | |||||
| const char *POOLING = "Pooling"; | |||||
| const char *ELTWISE = "Eltwise"; | |||||
| const char *RELU = "ReLU"; | |||||
| const char *RELU6 = "ReLU6"; | |||||
| const char *SIGMOID = "Sigmoid"; | |||||
| const char *ABSVAL = "AbsVal"; | |||||
| const char *TANH = "TanH"; | |||||
| const char *PRELU = "PReLU"; | |||||
| const char *BATCHNORM = "BatchNorm"; | |||||
| const char *FUSIONBATCHNORM = "FusionBatchNorm"; | |||||
| const char *SCALE = "Scale"; | |||||
| const char *FULL_CONNECTION = "FullConnection"; | |||||
| const char *SOFTMAX = "Softmax"; | |||||
| const char *PLUS = "Plus"; | |||||
| const char *ACTIVATION = "Activation"; | |||||
| const char *FLATTEN = "Flatten"; | |||||
| const char *ADD = "Add"; | |||||
| const char *SUB = "Sub"; | |||||
| const char *MUL = "Mul"; | |||||
| const char *MATMUL = "MatMul"; | |||||
| const char *RSQRT = "Rsqrt"; | |||||
| const char *BIASADD = "BiasAdd"; | |||||
| const char *RESHAPE = "Reshape"; | |||||
| const char *REFORMAT = "ReFormat"; | |||||
| const char *DEPCONVOLUTION = "ConvolutionDepthwise"; | |||||
| const char *DROPOUT = "Dropout"; | |||||
| const char *DROPOUTGENMASK = "DropOutGenMask"; | |||||
| const char *DROPOUTDOMASK = "DropOutDoMask"; | |||||
| const char *CONCAT = "Concat"; | |||||
| const char *ROIPOOLING = "ROIPooling"; | |||||
| const char *PROPOSAL = "Proposal"; | |||||
| const char *FSRDETECTIONOUTPUT = "FSRDetectionOutput"; | |||||
| const char *DETECTIONPOSTPROCESS = "Detectpostprocess"; | |||||
| const char *LRN = "LRN"; | |||||
| const char *TRANSDATA = "TransData"; | |||||
| const char *PERMUTE = "Permute"; | |||||
| const char *SSDNORMALIZE = "SSDNormalize"; | |||||
| const char *SSDPRIORBOX = "SSDPriorBox"; | |||||
| const char *NETOUTPUT = "NetOutput"; | |||||
| const char *SSDDETECTIONOUTPUT = "SSDDetectionOutput"; | |||||
| const char *REFINEDETDETECTIONOUTPUT = "RefinedetDetectionOutput"; | |||||
| const char *CHANNELAXPY = "ChannelAxpy"; | |||||
| const char *PSROIPOOLING = "PSROIPooling"; | |||||
| const char *POWER = "Power"; | |||||
| const char *POW = "Pow"; | |||||
| const char *ROIALIGN = "ROIAlign"; | |||||
| const char *PYTHON = "Python"; | |||||
| const char *FREESPACEEXTRACT = "FreespaceExtract"; | |||||
| const char *SPATIALTF = "SpatialTransform"; | |||||
| const char *SHAPE = "Shape"; | |||||
| const char *SHAPEN = "ShapeN"; | |||||
| const char *ARGMAX = "ArgMax"; | |||||
| const char *GATHERND = "GatherNd"; | |||||
| const char *GATHER = "Gather"; | |||||
| const char *REALDIV = "RealDiv"; | |||||
| const char *PACK = "Pack"; | |||||
| const char *SLICE = "Slice"; | |||||
| const char *SLICED = "SliceD"; | |||||
| const char *FLOORDIV = "FloorDiv"; | |||||
| const char *SQUEEZE = "Squeeze"; | |||||
| const char *UNSQUEEZE = "Unsqueeze"; | |||||
| const char *STRIDEDSLICE = "StridedSlice"; | |||||
| const char *RANGE = "Range"; | |||||
| const char *RPNPROPOSALS = "RpnProposals"; | |||||
| const char *DECODEBBOX = "DecodeBbox"; | |||||
| const char *PAD = "Pad"; | |||||
| const char *PADV2 = "PadV2"; | |||||
| const char *MIRRORPAD = "MirrorPad"; | |||||
| const char *TILE = "Tile"; | |||||
| const char *SIZE = "Size"; | |||||
| const char *CLIPBOXES = "ClipBoxes"; | |||||
| const char *FASTRCNNPREDICTIONS = "FastrcnnPredictions"; | |||||
| const char *SPLIT = "Split"; | |||||
| const char *SPLITV = "SplitV"; | |||||
| const char *EXPANDDIMS = "ExpandDims"; | |||||
| const char *EMPTY = "Empty"; | |||||
| const char *MEAN = "Mean"; | |||||
| const char *GREATER = "Greater"; | |||||
| const char *SWITCH = "Switch"; | |||||
| const char *SWITCHN = "SwitchN"; | |||||
| const char *MERGE = "Merge"; | |||||
| const char *SYMBOLICGRADIENT = "SymbolicGradient"; | |||||
| const char *REMOTECALL = "RemoteCall"; | |||||
| const char *_IF = "_If"; | |||||
| const char *STATELESSIF = "StatelessIf"; | |||||
| const char *IF = "If"; | |||||
| const char *CASE = "Case"; | |||||
| const char *_WHILE = "_While"; | |||||
| const char *WHILE = "While"; | |||||
| const char *STATELESSWHILE = "StatelessWhile"; | |||||
| const char *FOR = "For"; | |||||
| const char *PARTITIONEDCALL = "PartitionedCall"; | |||||
| const char *STATEFULPARTITIONEDCALL = "StatefulPartitionedCall"; | |||||
| const char *FAKEPARAM = "FakeParam"; | |||||
| const char *TRANSPOSE = "Transpose"; | |||||
| const char *TRANSPOSED = "TransposeD"; | |||||
| const char *CAST = "Cast"; | |||||
| const char *REGION = "Region"; | |||||
| const char *YOLO = "Yolo"; | |||||
| const char *YOLODETECTIONOUTPUT = "YoloDetectionOutput"; | |||||
| const char *FILL = "Fill"; | |||||
| const char *REVERSE = "Reverse"; | |||||
| const char *UNPACK = "Unpack"; | |||||
| const char *YOLO2REORG = "Yolo2Reorg"; | |||||
| const char *REDUCESUM = "ReduceSum"; | |||||
| const char *SUM = "Sum"; | |||||
| const char *CONSTANT = "Const"; | |||||
| const char *RESIZEBILINEAR = "ResizeBilinear"; | |||||
| const char *RESIZEBILINEARGRAD = "ResizeBilinearGrad"; | |||||
| const char *MAXIMUM = "Maximum"; | |||||
| const char *FRAMEWORKOP = "FrameworkOp"; | |||||
| const char *ARG = "_Arg"; | |||||
| const char *FUSEDBATCHNORMGRAD = "FusedBatchNormGrad"; | |||||
| const char *LSTM = "LSTM"; | |||||
| const char *HIGHWAY = "HighWay"; | |||||
| const char *RNN = "RNN"; | |||||
| const char *ATTENTIONDECODER = "AttentionDecoder"; | |||||
| const char *LOGICAL_NOT = "LogicalNot"; | |||||
| const char *LOGICAL_AND = "LogicalAnd"; | |||||
| const char *LOGICAL_OR = "LogicalOr"; | |||||
| const char *EQUAL = "Equal"; | |||||
| const char *NOTEQUAL = "NotEqual"; | |||||
| const char *INTERP = "Interp"; | |||||
| const char *SHUFFLECHANNEL = "ShuffleChannel"; | |||||
| const char *AIPP = "Aipp"; | |||||
| const char *MULTISHAPE = "MultiShape"; | |||||
| const char *RECIPROCAL = "Reciprocal"; | |||||
| const char *SELU = "Selu"; | |||||
| const char *ELU = "Elu"; | |||||
| const char *ACOSH = "Acosh"; | |||||
| const char *ASINH = "Asinh"; | |||||
| const char *MINIMUM = "Minimum"; | |||||
| const char *CLIP = "Clip"; | |||||
| const char *L2NORMALIZE = "L2Normalize"; | |||||
| const char *CROPANDRESIZE = "CropAndResize"; | |||||
| const char *UNUSEDCONST = "UnusedConst"; | |||||
| const char *SPARSETODENSE = "SparseToDense"; | |||||
| const char *NONMAXSUPPRESSION = "NonMaxSuppression"; | |||||
| const char *TOPKV2 = "TopKV2"; | |||||
| const char *INVERTPERMUTATION = "InvertPermutation"; | |||||
| const char *MULTINOMIAL = "Multinomial"; | |||||
| const char *REVERSESEQUENCE = "ReverseSequence"; | |||||
| const char *REDUCEPROD = "ReduceProd"; | |||||
| const char *REDUCEMAX = "ReduceMax"; | |||||
| const char *REDUCEMIN = "ReduceMin"; | |||||
| const char *EXTRACTIMAGEPATCHES = "ExtractImagePatches"; | |||||
| const char *SQRT = "Sqrt"; | |||||
| const char *REDUCEALL = "ReduceAll"; | |||||
| const char *RESIZENEARESTNEIGHBOR = "ResizeNearestNeighbor"; | |||||
| const char *SPACETOBATCHND = "SpaceToBatchND"; | |||||
| const char *BATCHTOSPACEND = "BatchToSpaceND"; | |||||
| const char *ASSERT = "Assert"; | |||||
| const char *GREATEREQUAL = "GreaterEqual"; | |||||
| const char *FLOOR = "Floor"; | |||||
| const char *RANDOMUNIFORM = "RandomUniform"; | |||||
| const char *BATCHMATMUL = "BatchMatMul"; | |||||
| const char *SPACETODEPTH = "SpaceToDepth"; | |||||
| const char *DEPTHTOSPACE = "DepthToSpace"; | |||||
| const char *RINT = "Rint"; | |||||
| const char *ATAN = "Atan"; | |||||
| const char *ATAN2 = "Atan2"; | |||||
| const char *ATANH = "Atanh"; | |||||
| const char *ACOS = "Acos"; | |||||
| const char *ASIN = "Asin"; | |||||
| const char *NEG = "Neg"; | |||||
| const char *LOG = "Log"; | |||||
| const char *TAN = "Tan"; | |||||
| const char *ROUND = "Round"; | |||||
| const char *UPSAMPLE = "Upsample"; | |||||
| const char *FLOORMOD = "FloorMod"; | |||||
| const char *LESS = "Less"; | |||||
| const char *LESSEQUAL = "LessEqual"; | |||||
| const char *ONEHOT = "OneHot"; | |||||
| const char *REFSWITCH = "RefSwitch"; | |||||
| const char *REFMERGE = "RefMerge"; | |||||
| const char *ENTER = "Enter"; | |||||
| const char *REFENTER = "RefEnter"; | |||||
| const char *LOOPCOND = "LoopCond"; | |||||
| const char *NEXTITERATION = "NextIteration"; | |||||
| const char *REFNEXTITERATION = "RefNextIteration"; | |||||
| const char *EXIT = "Exit"; | |||||
| const char *REFEXIT = "RefExit"; | |||||
| const char *CONTROLTRIGGER = "ControlTrigger"; | |||||
| const char *ZEROSLIKE = "ZerosLike"; | |||||
| const char *EXP = "Exp"; | |||||
| const char *WHERE = "Where"; | |||||
| const char *FAKEQUANTWITHMINMAXVARS = "FakeQuantWithMinMaxVars"; | |||||
| const char *SOFTPLUS = "Softplus"; | |||||
| const char *SOFTSIGN = "Softsign"; | |||||
| const char *COSH = "Cosh"; | |||||
| const char *SINH = "Sinh"; | |||||
| const char *SQUAREDDIFFERENCE = "SquaredDifference"; | |||||
| const char *REQUIREDSPACETOBATCHPADDINGS = "RequiredSpaceToBatchPaddings"; // for retinanet scope fusion | |||||
| const char *SSDPOSTPROCESSOR = "SSDPostProcessor"; | |||||
| const char *RETINANETBOXES = "RetinanetBoxes"; | |||||
| const char *RETINAMULTIANCHORS = "RetinaMultiAnchor"; | |||||
| const char *RETINANETCLIPPEDBOXES = "RetinanetClippedBoxes"; | |||||
| const char *RETINANETFILTEREDDETECTIONS = "RetinanetFilteredDetections"; | |||||
| const char *RETINANETPOSTPROCESSOR = "RetinanetPostProcessor"; | |||||
| const char *RETINANETANCHORS = "RetinanetAnchors"; | |||||
| const char *FASTERRCNNMAP = "FasterRCNNMap"; | |||||
| const char *FASTERRCNNMAP1 = "FasterRCNNMap1"; | |||||
| const char *FASTERRCNNSECONDSTAGEPOSTPROCESSOR = "FasterRCNNSecondStagePostprocessor"; | |||||
| const char *FASTERRCNNROIINTERPOOLING = "FasterRCNNROIInterPooling"; | |||||
| const char *FASTERRCNNFIRSTSTAGEPOSTPROCESSOR = "FasterRCNNFirstStagePostprocessor"; | |||||
| const char *FASTERRCNNGRIDANCHORGENERATOR = "FasterRCNNGridAnchorGenerator"; | |||||
| const char *ROIINTERPOOLING = "ROIInterPooling"; | |||||
| const char *FASTERRCNNCLIPTOWINDOW = "FasterRCNNClipToWindow"; | |||||
| const char *EMBEDLOOKUP = "EmbedLookup"; | |||||
| const char *HASHLOOKUP = "HashLookup"; | |||||
| const char *LSH_PROJ = "LshProject"; | |||||
| const char *SVDF = "SVDF"; | |||||
| const char *SSDANCHORGENERATOR = "SSDAnchorGenerator"; | |||||
| const char *IDENTITY = "Identity"; | |||||
| const char *IDENTITYN = "IdentityN"; | |||||
| const char *PLACEHOLDERWITHDEFAULT = "PlaceholderWithDefault"; | |||||
| const char *SELECT = "Select"; | |||||
| const char *GETSPAN = "GetSpan"; | |||||
| const char *STOPGRADIENT = "StopGradient"; | |||||
| const char *PREVENTGRADIENT = "PreventGradient"; | |||||
| const char *GUARANTEECONST = "GuaranteeConst"; | |||||
| const char *BROADCASTGRADIENTARGS = "BroadcastGradientArgs"; | |||||
| const char *BROADCASTARGS = "BroadcastArgs"; | |||||
| const char *CONFUSIONMATRIX = "ConfusionMatrix"; | |||||
| const char *RANK = "Rank"; | |||||
| const char *PLACEHOLDER = "PlaceHolder"; | |||||
| const char *END = "End"; | |||||
| const char *BASICLSTMCELL = "BasicLSTMCell"; | |||||
| const char *GETNEXT = "GetNext"; | |||||
| const char *INITDATA = "InitData"; | |||||
| const char *REFIDENTITY = "RefIdentity"; | |||||
| const char *BITCAST = "Bitcast"; | |||||
| /***************Ann special operator*************************/ | |||||
| const char *ANN_MEAN = "AnnMean"; | |||||
| const char *ANN_CONVOLUTION = "AnnConvolution"; | |||||
| const char *ANN_DEPCONVOLUTION = "AnnDepthConv"; | |||||
| const char *ANN_FULLCONNECTION = "AnnFullConnection"; | |||||
| const char *ANN_NETOUTPUT = "AnnNetOutput"; | |||||
| const char *ANN_DATA = "AnnData"; | |||||
| const char *ANN_RESHAPE = "AnnReshape"; | |||||
| const char *ANN_ADD = "AnnAdd"; | |||||
| const char *ANN_MUL = "AnnMul"; | |||||
| const char *ANN_SUB = "AnnSub"; | |||||
| const char *ANN_DIV = "AnnDiv"; | |||||
| const char *ANN_DEQUANTIZE = "AnnDequant"; | |||||
| const char *ANN_QUANTIZE = "AnnQuant"; | |||||
| const char *ANN_PAD = "AnnPad"; | |||||
| const char *ANN_RESIZE_BILINEAR = "AnnResizeBilinear"; | |||||
| /***************************************************/ | |||||
| /******************Training operator*************************/ | |||||
| const char *GATHERV2 = "GatherV2"; | |||||
| const char *CONVGRADFILTER = "Conv2DBackpropFilter"; | |||||
| const char *CONV2D = "Conv2D"; | |||||
| const char *CONV2DBACKPROPINPUT = "Conv2DBackpropInput"; | |||||
| const char *FUSEDBATCHNORM = "FusedBatchNorm"; | |||||
| const char *BIASADDGRAD = "BiasAddGrad"; | |||||
| const char *ACTIVATIONGRAD = "ReluGrad"; | |||||
| const char *MAXPOOLWITHARGMAX = "MaxPoolWithArgmax"; | |||||
| const char *MAXPOOLGRADWITHARGMAX = "MaxPoolGradWithArgmax"; | |||||
| const char *SPARSESOFTMAXCROSSENTROPYWITHLOGITS = "SparseSoftmaxCrossEntropyWithLogits"; | |||||
| const char *SNAPSHOT = "Snapshot"; | |||||
| const char *VAR = "Var"; | |||||
| const char *MEANGRAD = "MeanGrad"; | |||||
| const char *TRANSLATE = "Translate"; | |||||
| const char *ADDN = "AddN"; | |||||
| const char *L2LOSS = "L2Loss"; | |||||
| const char *MULTIPLY = "Multiply"; | |||||
| const char *HUBERLOSSGRAD = "HuberLossGrad"; | |||||
| const char *HUBERLOSS = "HuberLoss"; | |||||
| const char *NEGATIVE = "Negative"; | |||||
| const char *SSDCAST = "SSDCast"; | |||||
| const char *SPARSESOFTMAXCROSSENTROPY = "SsdSparseSoftmaxCrossEntropy"; | |||||
| const char *SPARSESOFTMAXCROSSENTROPYGRAD = "SsdSparseSoftmaxCrossEntropyGrad"; | |||||
| const char *SSDSQUEEZEFUSION = "SsdSqueezeFusion"; | |||||
| const char *CONCATFOUR2FIVE = "ConcatFour2Five"; | |||||
| const char *CONCATFIVE2FOUR = "ConcatFive2Four"; | |||||
| const char *SSDREALDIVTILEMUL = "SSDRealdivTileMul"; | |||||
| const char *SSDSUMMULREALDIVMEAN = "SSDSumMulRealdivMean"; | |||||
| const char *VARIABLEV2 = "VariableV2"; | |||||
| const char *VARHANDLEOP = "VarHandleOp"; | |||||
| const char *TEMPORARYVARIABLE = "TemporaryVariable"; | |||||
| const char *DESTROYTEMPORARYVARIABLE = "DestroyTemporaryVariable"; | |||||
| const char *VARIABLE = "Variable"; | |||||
| const char *ASSIGN = "Assign"; | |||||
| const char *ASSIGNVARIABLEOP = "AssignVariableOp"; | |||||
| const char *ASSIGNADD = "AssignAdd"; | |||||
| const char *ASSIGNADDVARIABLEOP = "AssignAddVariableOp"; | |||||
| const char *ASSIGNSUB = "AssignSub"; | |||||
| const char *ASSIGNSUBVARIABLEOP = "AssignSubVariableOp"; | |||||
| const char *APPLYMOMENTUM = "ApplyMomentum"; | |||||
| const char *RESOURCEAPPLYMOMENTUM = "ResourceApplyMomentum"; | |||||
| const char *SGD = "SGD"; | |||||
| const char *NOOP = "NoOp"; | |||||
| const char *READVARIABLEOP = "ReadVariableOp"; | |||||
| const char *PARALLELCONCATSTART = "_ParallelConcatStart"; | |||||
| const char *CONSTANTOP = "Constant"; | |||||
| const char *DEPTHWISECONV2DBACKPROPFILTER = "DepthwiseConv2dNativeBackpropFilter"; | |||||
| const char *DEPTHWISECONV2DBACKPORPINPUT = "DepthwiseConv2dNativeBackpropInput"; | |||||
| const char *DEPTHWISECONV2DFORWARDNATIVE = "DepthwiseConv2dNative"; | |||||
| const char *DROPOUTGRAD = "DropOutGrad"; | |||||
| const char *APPLYRMSPROPMIXEDPRECISION = "apply_rms_prop_mixed_precision"; | |||||
| const char *APPLYRMSPROP = "ApplyRMSProp"; | |||||
| const char *RELU6GRAD = "Relu6Grad"; | |||||
| const char *AVGPOOLGRAD = "AvgPoolGrad"; | |||||
| const char *CONCATV2 = "ConcatV2"; | |||||
| const char *CONCATOFFSET = "ConcatOffset"; | |||||
| const char *LAYERNORMGRAD = "LayerNormGrad"; | |||||
| const char *LAYERNORM = "LayerNorm"; | |||||
| const char *LARS = "Lars"; | |||||
| const char *DYNAMICSTITCH = "DynamicStitch"; | |||||
| /***************************************************/ | |||||
| const char *SQUARE = "Square"; | |||||
| const char *HCOMBROADCAST = "HcomBroadcast"; | |||||
| const char *HCOMALLGATHER = "HcomAllGather"; | |||||
| const char *HCOMALLREDUCE = "HcomAllReduce"; | |||||
| const char *HCOMREDUCESCATTER = "HcomReduceScatter"; | |||||
| const char *HCOMSEND = "HcomSend"; | |||||
| const char *HCOMRECEIVE = "HcomReceive"; | |||||
| const char *HCOMREMOTEREAD = "HcomRemoteRead"; | |||||
| const char *HCOMREMOTEWRITE = "HcomRemoteWrite"; | |||||
| const char *VARASSIGN = "VarAssign"; | |||||
| const char *VARISINITIALIZEDOP = "VarIsInitializedOp"; | |||||
| const char *LogTimeStamp = "LogTimeStamp"; | |||||
| const char *ISVARIABLEINITIALIZED = "IsVariableInitialized"; | |||||
| const char *STREAMSWITCH = "StreamSwitch"; | |||||
| const char *STREAMSWITCHN = "StreamSwitchN"; | |||||
| const char *STREAMACTIVE = "StreamActive"; | |||||
| const char *MEMCPYASYNC = "MemcpyAsync"; | |||||
| const char *MEMCPYADDRASYNC = "MemcpyAddrAsync"; | |||||
| const char *STREAMMERGE = "StreamMerge"; | |||||
| const char *ENDGRAPH = "EndGraph"; | |||||
| const char *SEND = "Send"; | |||||
| const char *RECV = "Recv"; | |||||
| const char *ENDOFSEQUENCE = "EndOfSequence"; | |||||
| const char *LABELSET = "LabelSet"; | |||||
| const char *LABELGOTO = "LabelGoto"; | |||||
| const char *LABELGOTOEX = "LabelGotoEx"; | |||||
| const char *LABELSWITCH = "LabelSwitch"; | |||||
| const char *LABELSWITCHBYINDEX = "LabelSwitchByIndex"; | |||||
| const char *ATOMICADDRCLEAN = "AtomicAddrClean"; | |||||
| const char *ABS_GRAD = "AbsGrad"; | |||||
| const char *ACCUMULATE_N_V2 = "AccumulateNV2"; | |||||
| const char *ACOS_GRAD = "AcosGrad"; | |||||
| const char *ACOSH_GRAD = "AcoshGrad"; | |||||
| const char *ANY = "Any"; | |||||
| const char *APPROXIMATE_EQUAL = "ApproximateEqual"; | |||||
| const char *ASIN_GRAD = "AsinGrad"; | |||||
| const char *ASINH_GRAD = "AsinhGrad"; | |||||
| const char *ATAN_GRAD = "AtanGrad"; | |||||
| const char *BROADCAST_TO = "BroadcastTo"; | |||||
| const char *ELU_GRAD = "EluGrad"; | |||||
| const char *ADD_V2 = "AddV2"; | |||||
| const char *DATAFORMATDIMMAP = "DataFormatDimMap"; | |||||
| const char *DATAFORMATVECPERMUTE = "DataFormatVecPermute"; | |||||
| const char *BESSELI0E = "BesselI0e"; | |||||
| const char *BESSELI1E = "BesselI1e"; | |||||
| const char *APPLYADADELTA = "ApplyAdadelta"; | |||||
| const char *APPLYADAGRAD = "ApplyAdagrad"; | |||||
| const char *APPLYADAGRADDA = "ApplyAdagradDA"; | |||||
| const char *APPLYADAM = "ApplyAdam"; | |||||
| const char *APPLYADAMAX = "ApplyAdaMax"; | |||||
| const char *APPLYADDSIGN = "ApplyAddSign"; | |||||
| const char *APPLYCENTEREDRMSPROP = "ApplyCenteredRMSProp"; | |||||
| const char *APPLYFTRL = "ApplyFtrl"; | |||||
| const char *APPLYFTRLV2 = "ApplyFtrlV2"; | |||||
| const char *APPLYGRADIENTDESCENT = "ApplyGradientDescent"; | |||||
| const char *APPLYPOWERSIGN = "ApplyPowerSign"; | |||||
| const char *APPLYPROXIMALADAGRAD = "ApplyProximalAdagrad"; | |||||
| const char *APPLYPROXIMALGRADIENTDESCENT = "ApplyProximalGradientDescent"; | |||||
| const char *DEQUANTIZE = "Dequantize"; | |||||
| const char *FOCAL_LOSS = "FocalLoss"; | |||||
| const char *FOCAL_LOSS_GRAD = "FocalLossGrad"; | |||||
| const char *SMOOTHL1_LOSS = "SmoothL1Loss"; | |||||
| const char *SMOOTHL1_LOSS_grad = "SmoothL1LossGrad"; | |||||
| const char *REDUCEMEAN = "ReduceMean"; | |||||
| const char *CONCAT_V2 = "ConcatV2"; | |||||
| const char *ONEHOT_V2 = "OneHotV2"; | |||||
| const char *SLICE_V2 = "SliceV2"; | |||||
| const char *TILE_V2 = "TileV2"; | |||||
| const char *SUM_V2 = "SumV2"; | |||||
| // Common type when the operator has the same name | |||||
| const char *DETECTIONOUTPUT = "DetectionOutput"; | |||||
| // Custom operator | |||||
| const char *CUSTOMOP = "CustomOp"; | |||||
| const char *CUSTOMOP_NCHW = "CustomOpNchw"; | |||||
| const char *CUSTOMOP_NHWC = "CustomOpNhwc"; | |||||
| const char *CUSTOMOP_NC1HWC0 = "CustomOpNc1hwc0"; | |||||
| // Depthwise 4d_2_6d,6d_2_4d | |||||
| const char *DEPTHWISEWEIGHT4D26D = "depthwise_weight_4d_2_6d"; | |||||
| const char *DEPTHWISEWEIGHT6D24D = "depthwise_weight_6d_2_4d"; | |||||
| const char *SQRTGRAD = "SqrtGrad"; | |||||
| const char *SIGMOIDGRAD = "SigmoidGrad"; | |||||
| const char *TRANSSHAPE = "TransShape"; | |||||
| // Horovod operator | |||||
| const char *HVDCALLBACKALLREDUCE = "HorovodAllreduce"; | |||||
| const char *HVDCALLBACKALLGATHER = "HorovodAllgather"; | |||||
| const char *HVDCALLBACKBROADCAST = "HorovodBroadcast"; | |||||
| const char *HVDWAIT = "HorovodWait"; | |||||
| /// | |||||
| /// @brief Magic number of model file | |||||
| /// | |||||
| const uint32_t MODEL_FILE_MAGIC_NUM = 0x444F4D49; // magic number | |||||
| /// | |||||
| /// @brief Model head length | |||||
| /// | |||||
| const uint32_t MODEL_FILE_HEAD_LEN = 256; | |||||
| const uint32_t MODEL_VERSION = 0x10000000; ///< Model version 1.0/// | |||||
| /// | |||||
| /// @ingroup domi_omg | |||||
| /// @brief alpha default value | |||||
| /// | |||||
| const float ALPHA_DEFAULT_VALUE = 1.0; | |||||
| /// | |||||
| /// @ingroup domi_omg | |||||
| /// @brief beta default value | |||||
| /// | |||||
| const float BETA_DEFAULT_VALUE = 0.0; | |||||
| /// | |||||
| /// @ingroup domi_omg | |||||
| /// @brief Input node type | |||||
| /// | |||||
| const std::string INPUT_TYPE = "Input"; | |||||
| const std::string DUMMY_DATA = "DummyData"; | |||||
| // for fusion op plugin | |||||
| const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type"; | |||||
| const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; | |||||
| const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; | |||||
| /// | |||||
| /// @ingroup domi_omg | |||||
| /// @brief DATA node type | |||||
| /// | |||||
| const std::string DATA_TYPE = "Data"; | |||||
| /// | |||||
| /// @ingroup domi_omg | |||||
| /// @brief Frame operator type | |||||
| /// | |||||
| const std::string FRAMEWORK_OP_TYPE = "FrameworkOp"; | |||||
| /// | |||||
| /// @ingroup domi_omg | |||||
| /// @brief Convolution node type | |||||
| /// | |||||
| const std::string NODE_NAME_NET_OUTPUT = "Node_Output"; | |||||
| } // namespace parser | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,221 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "parser_utils.h" | |||||
| #include "external/ge/ge_api_types.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/util.h" | |||||
| #include "graph/anchor.h" | |||||
| #include "graph/compute_graph.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/utils/op_desc_utils.h" | |||||
| #include "register/op_registry.h" | |||||
| namespace ge { | |||||
| namespace { | |||||
| Status HandleNewOp(const NodePtr &node, const ComputeGraphPtr &compute_graph, const NodePtr &new_node) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| GE_CHECK_NOTNULL(new_node); | |||||
| if (new_node->SetOwnerComputeGraph(compute_graph) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "Set owner graph for node:%s failed.", new_node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| auto op_desc = new_node->GetOpDesc(); | |||||
| static std::atomic_long new_node_index(0); | |||||
| auto new_name = "PartitionedCall_" + new_node->GetName() + "_" + to_string(new_node_index++); | |||||
| op_desc->SetName(new_name); | |||||
| bool ret = ge::AttrUtils::SetListStr(op_desc, | |||||
| ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, | |||||
| std::move(std::vector<std::string>{node->GetName()})); | |||||
| if (!ret) { | |||||
| GELOGW("Set %s to %s fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES.c_str(), op_desc->GetName().c_str()); | |||||
| } | |||||
| GELOGD("Handle new op[%s] for node[%s] success.", new_node->GetName().c_str(), node->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| } | |||||
| Status ParserUtils::ExpandOneToManyGraph(Graph &graph) { | |||||
| GELOGD("Begin run ParserUtils::ExpandOneToManyGraph."); | |||||
| ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); | |||||
| for (const auto &n : compute_graph->GetDirectNode()) { | |||||
| GE_CHECK_NOTNULL(n); | |||||
| std::string ori_type; | |||||
| (void)AttrUtils::GetStr(n->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, ori_type); | |||||
| domi::ParseOpToGraphFunc parse_op_to_graph_func = | |||||
| domi::OpRegistry::Instance()->GetParseOpToGraphFunc(n->GetType(), ori_type); | |||||
| if (parse_op_to_graph_func == nullptr) { | |||||
| GELOGD("node:%s type:%s ori type:%s has no parse_op_to_graph_func.", | |||||
| n->GetName().c_str(), n->GetType().c_str(), ori_type.c_str()); | |||||
| continue; | |||||
| } | |||||
| GELOGI("node:%s type:%s ori type:%s has registered one to many parser func.", | |||||
| n->GetName().c_str(), n->GetType().c_str(), ori_type.c_str()); | |||||
| Graph subgraph("one_to_many_graph"); | |||||
| Operator op = OpDescUtils::CreateOperatorFromNode(n); | |||||
| Status ret = parse_op_to_graph_func(op, subgraph); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(FAILED, "Get one to many graph failed for op:%s.", op.GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| ret = ExpandNodeToSubgraph(subgraph, n, graph); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(FAILED, "Expand one to many graph failed for op:%s.", op.GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| GELOGD("run ParserUtils::ExpandOneToManyGraph success."); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph) { | |||||
| ComputeGraphPtr sub_compute_graph = GraphUtils::GetComputeGraph(subgraph); | |||||
| GE_CHECK_NOTNULL(sub_compute_graph); | |||||
| ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| // add subgraph node to graph. | |||||
| std::vector<NodePtr> input_nodes; | |||||
| for (const auto &n : sub_compute_graph->GetDirectNode()) { | |||||
| auto new_node = compute_graph->AddNode(n); | |||||
| GE_CHECK_NOTNULL(new_node); | |||||
| if (HandleNewOp(node, compute_graph, new_node) != SUCCESS) { | |||||
| GELOGE(FAILED, "Handle new op[%s] for node[%s] failed.", new_node->GetName().c_str(), node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (new_node->GetType() == "Data") { | |||||
| input_nodes.emplace_back(new_node); | |||||
| } | |||||
| } | |||||
| // handle input context. | |||||
| Status ret = HandleInputContext(node, input_nodes, compute_graph); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(FAILED, "run ParserUtils::HandleInputContext failed."); | |||||
| return FAILED; | |||||
| } | |||||
| // handle output context. | |||||
| std::vector<std::pair<NodePtr, int32_t>> out_node_index = sub_compute_graph->GetGraphOutNodesInfo(); | |||||
| ret = HandleOutputContext(node, out_node_index); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(FAILED, "run ParserUtils::HandleOutputContext failed."); | |||||
| return FAILED; | |||||
| } | |||||
| graphStatus graph_status = GraphUtils::RemoveNodeWithoutRelink(compute_graph, node); | |||||
| if (graph_status != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "Remove node:%s failed.", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| graph_status = compute_graph->TopologicalSorting(); | |||||
| if (graph_status != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "Topological sorting failed."); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ParserUtils::HandleInputContext(const NodePtr &node, | |||||
| const std::vector<NodePtr> &input_nodes, | |||||
| const ComputeGraphPtr &compute_graph) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| for (const auto &in_n : input_nodes) { | |||||
| GE_CHECK_NOTNULL(in_n); | |||||
| int index; | |||||
| if (!AttrUtils::GetInt(in_n->GetOpDesc(), ATTR_NAME_INDEX, index)) { | |||||
| GELOGE(FAILED, "Get attr index of node:%s failed.", in_n->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("Begin to handle input node:%s with index:%d.", in_n->GetName().c_str(), index); | |||||
| // get node's in data anchor and peer out anchor | |||||
| auto node_in_anchor = node->GetInDataAnchor(index); | |||||
| GE_CHECK_NOTNULL(node_in_anchor); | |||||
| auto src_out_anchor = node_in_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(src_out_anchor); | |||||
| auto data_out_anchor = in_n->GetOutDataAnchor(0); | |||||
| GE_CHECK_NOTNULL(data_out_anchor); | |||||
| for (const auto &peer_in_anchor : data_out_anchor->GetPeerInDataAnchors()) { | |||||
| // add data edge | |||||
| graphStatus ret = GraphUtils::RemoveEdge(data_out_anchor, peer_in_anchor); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "remove data out anchor and peer in anchor failed."); | |||||
| return FAILED; | |||||
| } | |||||
| ret = GraphUtils::RemoveEdge(src_out_anchor, node_in_anchor); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "remove node in anchor and peer out anchor failed."); | |||||
| return FAILED; | |||||
| } | |||||
| ret = GraphUtils::AddEdge(src_out_anchor, peer_in_anchor); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "link node's peer out anchor and data's peer in anchor failed."); | |||||
| return FAILED; | |||||
| } | |||||
| // add control edge | |||||
| if (node->GetInControlAnchor() != nullptr) { | |||||
| for (const auto &out_anchor : node->GetInControlAnchor()->GetPeerAnchors()) { | |||||
| graphStatus ret = GraphUtils::AddEdge(out_anchor, peer_in_anchor->GetOwnerNode()->GetInControlAnchor()); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "add control edge failed."); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| graphStatus ret = GraphUtils::RemoveNodeWithoutRelink(compute_graph, in_n); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "remove node:%s failed.", in_n->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ParserUtils::HandleOutputContext(const NodePtr &node, | |||||
| const std::vector<std::pair<NodePtr, int32_t>> &out_node_index) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| GELOGD("The size of out node is %zu", out_node_index.size()); | |||||
| for (size_t index = 0; index < out_node_index.size(); index++) { | |||||
| auto node_out_anchor = node->GetOutDataAnchor(index); | |||||
| if (node_out_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| NodePtr out_node = out_node_index[index].first; | |||||
| int32_t out_index = out_node_index[index].second; | |||||
| GELOGD("Begin to handle output node:%s[%zu] with index:%zu", out_node->GetName().c_str(), out_index, index); | |||||
| auto src_out_anchor = out_node->GetOutDataAnchor(out_index); // get out node's out anchor. | |||||
| GE_CHECK_NOTNULL(src_out_anchor); | |||||
| for (const auto &dest_in_anchor : node_out_anchor->GetPeerInDataAnchors()) { | |||||
| graphStatus ret = GraphUtils::RemoveEdge(node_out_anchor, dest_in_anchor); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "remove node's out anchor and peer in anchor failed."); | |||||
| return FAILED; | |||||
| } | |||||
| ret = GraphUtils::AddEdge(src_out_anchor, dest_in_anchor); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "link node's peer out anchor and out node's out anchor failed."); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_COMMON_PARSER_UTILS_H_ | |||||
| #define PARSER_COMMON_PARSER_UTILS_H_ | |||||
| #include "graph/graph.h" | |||||
| #include "graph/node.h" | |||||
| #include "external/ge/ge_api_error_codes.h" | |||||
| namespace ge { | |||||
| class ParserUtils { | |||||
| public: | |||||
| static Status ExpandOneToManyGraph(Graph &graph); | |||||
| private: | |||||
| static Status ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph); | |||||
| static Status HandleInputContext(const NodePtr &node, | |||||
| const std::vector<NodePtr> &input_nodes, | |||||
| const ComputeGraphPtr &compute_graph); | |||||
| static Status HandleOutputContext(const NodePtr &node, const std::vector<std::pair<NodePtr, int32_t>> &out_node_index); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // PARSER_COMMON_PARSER_UTILS_H_ | |||||
| @@ -1,83 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "parser/common/pass_manager.h" | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include "common/debug/log.h" | |||||
| #include "graph/utils/node_utils.h" | |||||
| #include "omg/omg_inner_types.h" | |||||
| namespace ge { | |||||
| namespace parser { | |||||
| const vector<std::pair<std::string, GraphPass *>> &PassManager::GraphPasses() const { return names_to_graph_passes_; } | |||||
| Status PassManager::AddPass(const string &pass_name, GraphPass *pass) { | |||||
| GE_CHECK_NOTNULL(pass); | |||||
| names_to_graph_passes_.emplace_back(pass_name, pass); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status PassManager::Run(const ComputeGraphPtr &graph) { | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| return Run(graph, names_to_graph_passes_); | |||||
| } | |||||
| Status PassManager::Run(const ComputeGraphPtr &graph, vector<std::pair<std::string, GraphPass *>> &names_to_passes) { | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| bool not_changed = true; | |||||
| for (auto &pass_pair : names_to_passes) { | |||||
| const auto &pass = pass_pair.second; | |||||
| const auto &pass_name = pass_pair.first; | |||||
| GE_CHECK_NOTNULL(pass); | |||||
| PARSER_TIMESTAMP_START(PassRun); | |||||
| Status status = pass->Run(graph); | |||||
| if (status == SUCCESS) { | |||||
| not_changed = false; | |||||
| } else if (status != NOT_CHANGED) { | |||||
| GELOGE(status, "Pass Run failed on graph %s", graph->GetName().c_str()); | |||||
| return status; | |||||
| } | |||||
| for (const auto &subgraph :graph->GetAllSubgraphs()) { | |||||
| GE_CHECK_NOTNULL(subgraph); | |||||
| GE_CHK_STATUS_RET(pass->ClearStatus(), "pass clear status failed for subgraph %s", subgraph->GetName().c_str()); | |||||
| string subgraph_pass_name = pass_name + "::" + graph->GetName(); | |||||
| PARSER_TIMESTAMP_START(PassRunSubgraph); | |||||
| status = pass->Run(subgraph); | |||||
| PARSER_TIMESTAMP_END(PassRunSubgraph, subgraph_pass_name.c_str()); | |||||
| if (status == SUCCESS) { | |||||
| not_changed = false; | |||||
| } else if (status != NOT_CHANGED) { | |||||
| GELOGE(status, "Pass Run failed on subgraph %s", subgraph->GetName().c_str()); | |||||
| return status; | |||||
| } | |||||
| } | |||||
| PARSER_TIMESTAMP_END(PassRun, pass_name.c_str()); | |||||
| } | |||||
| return not_changed ? NOT_CHANGED : SUCCESS; | |||||
| } | |||||
| PassManager::~PassManager() { | |||||
| for (auto &pass_pair : names_to_graph_passes_) { | |||||
| auto &pass = pass_pair.second; | |||||
| GE_DELETE_NEW_SINGLE(pass); | |||||
| } | |||||
| } | |||||
| } // namespace parser | |||||
| } // namespace ge | |||||
| @@ -1,76 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_COMMON_PASS_MANAGER_H_ | |||||
| #define PARSER_COMMON_PASS_MANAGER_H_ | |||||
| #include <vector> | |||||
| #include "inc/graph_pass.h" | |||||
| using std::vector; | |||||
| namespace ge { | |||||
| namespace parser { | |||||
| /// | |||||
| /// @ingroup domi_omg | |||||
| /// @brief pass manager | |||||
| /// @author | |||||
| /// | |||||
| class PassManager { | |||||
| public: | |||||
| /// | |||||
| /// get graph passes | |||||
| /// @author | |||||
| /// | |||||
| const vector<std::pair<std::string, GraphPass *>> &GraphPasses() const; | |||||
| /// | |||||
| /// Add graph pass | |||||
| /// @param [in] pass Pass to be added, it will be destroyed when pass manager destroys. | |||||
| /// @author | |||||
| /// | |||||
| Status AddPass(const string &pass_name, GraphPass *pass); | |||||
| /// | |||||
| /// Optimize graph with added pass | |||||
| /// @param [inout] graph graph to be optimized | |||||
| /// @return SUCCESS optimize successfully | |||||
| /// @return NOT_CHANGED not optimized | |||||
| /// @return others optimize failed | |||||
| /// @author | |||||
| /// | |||||
| Status Run(const ge::ComputeGraphPtr &graph); | |||||
| /// | |||||
| /// Optimize graph with specified pass | |||||
| /// @param [inout] graph graph to be optimized | |||||
| /// @param [in] passes passes to be used | |||||
| /// @return SUCCESS optimize successfully | |||||
| /// @return NOT_CHANGED not optimized | |||||
| /// @return others optimized failed | |||||
| /// @author | |||||
| /// | |||||
| static Status Run(const ge::ComputeGraphPtr &graph, vector<std::pair<std::string, GraphPass *>> &passes); | |||||
| ~PassManager(); | |||||
| private: | |||||
| vector<std::pair<std::string, GraphPass *>> names_to_graph_passes_; | |||||
| }; | |||||
| } // namespace parser | |||||
| } // namespace ge | |||||
| #endif // PARSER_COMMON_PASS_MANAGER_H_ | |||||
| @@ -23,7 +23,6 @@ | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "omg/omg.h" | #include "omg/omg.h" | ||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "parser/common/model_saver.h" | |||||
| #include "register/op_registry.h" | #include "register/op_registry.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -56,7 +55,7 @@ void PreChecker::Init() { | |||||
| fmk_op_types_ = nullptr; | fmk_op_types_ = nullptr; | ||||
| // Currently only Caffe and tensorflow are supported | // Currently only Caffe and tensorflow are supported | ||||
| domi::FrameworkType fmk_type = GetParserContext().type; | |||||
| domi::FrameworkType fmk_type = domi::GetContext().type; | |||||
| if (fmk_type == domi::CAFFE) | if (fmk_type == domi::CAFFE) | ||||
| fmk_op_types_ = &caffe_op_map; | fmk_op_types_ = &caffe_op_map; | ||||
| else if (fmk_type == domi::TENSORFLOW) | else if (fmk_type == domi::TENSORFLOW) | ||||
| @@ -119,8 +118,8 @@ FMK_FUNC_HOST_VISIBILITY Status PreChecker::CheckType(OpId id, bool is_tensorflo | |||||
| // If the user explicitly specifies the mapping relationship of the operator type through | // If the user explicitly specifies the mapping relationship of the operator type through | ||||
| // the -- OP_name_map parameter, the type specified by the user is used. | // the -- OP_name_map parameter, the type specified by the user is used. | ||||
| auto op_map_iter = GetParserContext().op_conf_map.find(type); | |||||
| if (op_map_iter != GetParserContext().op_conf_map.end()) { | |||||
| auto op_map_iter = domi::GetContext().op_conf_map.find(type); | |||||
| if (op_map_iter != domi::GetContext().op_conf_map.end()) { | |||||
| type = op_map_iter->second; | type = op_map_iter->second; | ||||
| } | } | ||||
| @@ -233,7 +232,7 @@ Status PreChecker::Save(string file) { | |||||
| } | } | ||||
| // Save JSON data to a file | // Save JSON data to a file | ||||
| GE_RETURN_WITH_LOG_IF_ERROR(ge::parser::ModelSaver::SaveJsonToFile(file.c_str(), model), "Save failed."); | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(ModelSaver::SaveJsonToFile(file.c_str(), model), "Save failed."); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -19,7 +19,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "common/types.h" | |||||
| #include "omg/omg_inner_types.h" | #include "omg/omg_inner_types.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -45,9 +45,6 @@ message AippOpParams { | |||||
| // 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。 | // 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。 | ||||
| uint32 related_input_rank = 2; | uint32 related_input_rank = 2; | ||||
| // related_input_name is optional and the top name of data node which inserts aipp | |||||
| string related_input_name = 6; | |||||
| // input_edge_idx参数为可选,类型为整型,配置范围为>=0。 | // input_edge_idx参数为可选,类型为整型,配置范围为>=0。 | ||||
| // 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。 | // 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。 | ||||
| // 配置值 <= Data算子输出边的个数。 | // 配置值 <= Data算子输出边的个数。 | ||||
| @@ -27,7 +27,6 @@ | |||||
| #include "common/types.h" | #include "common/types.h" | ||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "common/debug/log.h" | #include "common/debug/log.h" | ||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include "ge/ge_api_types.h" | #include "ge/ge_api_types.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| @@ -159,7 +158,7 @@ bool SaveIdentifierOpMapInfo(const string &line, std::map<int, std::pair<string | |||||
| } | } | ||||
| bool CheckRealPath(const char *file_path) { | bool CheckRealPath(const char *file_path) { | ||||
| string dest_path = ge::parser::RealPath(file_path); | |||||
| string dest_path = ge::RealPath(file_path); | |||||
| if (dest_path.empty()) { | if (dest_path.empty()) { | ||||
| GELOGW("Path [%s] is not real existed.", file_path); | GELOGW("Path [%s] is not real existed.", file_path); | ||||
| return false; | return false; | ||||
| @@ -185,7 +184,7 @@ Status ProtoFileParser::CreatProtoFile() { | |||||
| fusion_proto_path += "/" + CreatTmpName(kTmpFileNameLen); | fusion_proto_path += "/" + CreatTmpName(kTmpFileNameLen); | ||||
| } | } | ||||
| int fd = open(fusion_proto_path.c_str(), O_RDWR | O_CREAT | O_TRUNC, S_IRUSR | S_IWUSR | S_IRGRP); | |||||
| int fd = open(fusion_proto_path.c_str(), O_RDWR | O_CREAT | O_TRUNC, 0640); | |||||
| if (fd < kOpenRetValue) { | if (fd < kOpenRetValue) { | ||||
| GELOGE(FAILED, "creat tmp proto file[%s] failed.", fusion_proto_path.c_str()); | GELOGE(FAILED, "creat tmp proto file[%s] failed.", fusion_proto_path.c_str()); | ||||
| return FAILED; | return FAILED; | ||||
| @@ -19,7 +19,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include "common/debug/log.h" | #include "common/debug/log.h" | ||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "common/op/ge_op_utils.h" | #include "common/op/ge_op_utils.h" | ||||
| #include "common/op_map.h" | #include "common/op_map.h" | ||||
| #include "common/util.h" | #include "common/util.h" | ||||
| @@ -38,6 +38,8 @@ FMK_FUNC_HOST_VISIBILITY OpRegistrationTbe *OpRegistrationTbe::Instance() { | |||||
| } | } | ||||
| bool OpRegistrationTbe::Finalize(const OpRegistrationData ®_data, bool is_train) { | bool OpRegistrationTbe::Finalize(const OpRegistrationData ®_data, bool is_train) { | ||||
| ge::OpTypeContainer::Instance()->Register(reg_data.GetOmOptype()); | |||||
| static std::map<domi::FrameworkType, std::map<std::string, std::string> *> op_map = {{CAFFE, &caffe_op_map}}; | static std::map<domi::FrameworkType, std::map<std::string, std::string> *> op_map = {{CAFFE, &caffe_op_map}}; | ||||
| if (is_train) { | if (is_train) { | ||||
| op_map[domi::TENSORFLOW] = &tensorflow_train_op_map; | op_map[domi::TENSORFLOW] = &tensorflow_train_op_map; | ||||
| @@ -55,7 +57,8 @@ bool OpRegistrationTbe::Finalize(const OpRegistrationData ®_data, bool is_tra | |||||
| continue; | continue; | ||||
| } else { | } else { | ||||
| (*fmk_op_map)[tmp] = reg_data.GetOmOptype(); | (*fmk_op_map)[tmp] = reg_data.GetOmOptype(); | ||||
| GELOGD("First register in parser initialize, original type: %s, om_optype: %s, imply type: %s.", tmp.c_str(), | |||||
| GELOGD("First register in parser initilize, original type: %s, om_optype: %s, imply type: %s.", | |||||
| tmp.c_str(), | |||||
| reg_data.GetOmOptype().c_str(), TypeUtils::ImplyTypeToSerialString(reg_data.GetImplyType()).c_str()); | reg_data.GetOmOptype().c_str(), TypeUtils::ImplyTypeToSerialString(reg_data.GetImplyType()).c_str()); | ||||
| } | } | ||||
| } | } | ||||
| @@ -79,7 +82,7 @@ bool OpRegistrationTbe::RegisterParser(const OpRegistrationData ®_data) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| std::shared_ptr<TensorFlowCustomParserAdapter> tf_parser_adapter = | std::shared_ptr<TensorFlowCustomParserAdapter> tf_parser_adapter = | ||||
| ge::parser::MakeShared<TensorFlowCustomParserAdapter>(); | |||||
| ge::MakeShared<TensorFlowCustomParserAdapter>(); | |||||
| if (tf_parser_adapter == nullptr) { | if (tf_parser_adapter == nullptr) { | ||||
| GELOGE(PARAM_INVALID, "Create tf parser adapter failed."); | GELOGE(PARAM_INVALID, "Create tf parser adapter failed."); | ||||
| return false; | return false; | ||||
| @@ -94,20 +97,22 @@ bool OpRegistrationTbe::RegisterParser(const OpRegistrationData ®_data) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| GELOGI("Register fusion custom op parser: %s", reg_data.GetOmOptype().c_str()); | GELOGI("Register fusion custom op parser: %s", reg_data.GetOmOptype().c_str()); | ||||
| std::shared_ptr<TensorFlowFusionCustomParserAdapter> tf_fusion_parser_adapter = | |||||
| ge::parser::MakeShared<TensorFlowFusionCustomParserAdapter>(); | |||||
| std::shared_ptr<TensorFlowFusionCustomParserAdapter> | |||||
| tf_fusion_parser_adapter = ge::MakeShared<TensorFlowFusionCustomParserAdapter>(); | |||||
| if (tf_fusion_parser_adapter == nullptr) { | if (tf_fusion_parser_adapter == nullptr) { | ||||
| GELOGE(PARAM_INVALID, "Create tf fusion parser adapter failed."); | GELOGE(PARAM_INVALID, "Create tf fusion parser adapter failed."); | ||||
| return false; | return false; | ||||
| } | } | ||||
| OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar( | OpParserRegisterar registerar __attribute__((unused)) = OpParserRegisterar( | ||||
| domi::TENSORFLOW, reg_data.GetOmOptype(), | domi::TENSORFLOW, reg_data.GetOmOptype(), | ||||
| [=]() -> std::shared_ptr<OpParser> { return tf_fusion_parser_adapter; }, true); | |||||
| [=]() -> std::shared_ptr<OpParser> { return tf_fusion_parser_adapter; }, | |||||
| true); | |||||
| } | } | ||||
| } else { | } else { | ||||
| std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(reg_data.GetFrameworkType()); | std::shared_ptr<OpParserFactory> factory = OpParserFactory::Instance(reg_data.GetFrameworkType()); | ||||
| if (factory == nullptr) { | if (factory == nullptr) { | ||||
| GELOGE(INTERNAL_ERROR, "Get op parser factory for %s failed.", | |||||
| GELOGE(INTERNAL_ERROR, | |||||
| "Get op parser factory for %s failed.", | |||||
| TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()).c_str()); | TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()).c_str()); | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -119,12 +124,13 @@ bool OpRegistrationTbe::RegisterParser(const OpRegistrationData ®_data) { | |||||
| PARSER_CREATOR_FN func = CustomParserAdapterRegistry::Instance()->GetCreateFunc(reg_data.GetFrameworkType()); | PARSER_CREATOR_FN func = CustomParserAdapterRegistry::Instance()->GetCreateFunc(reg_data.GetFrameworkType()); | ||||
| if (func == nullptr) { | if (func == nullptr) { | ||||
| GELOGE(INTERNAL_ERROR, "Get custom parser adapter failed for fmk type %s.", | |||||
| GELOGW("Get custom parser adapter failed for fmk type %s.", | |||||
| TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()).c_str()); | TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()).c_str()); | ||||
| return false; | return false; | ||||
| } | } | ||||
| OpParserFactory::Instance(reg_data.GetFrameworkType())->RegisterCreator(reg_data.GetOmOptype(), func); | OpParserFactory::Instance(reg_data.GetFrameworkType())->RegisterCreator(reg_data.GetOmOptype(), func); | ||||
| GELOGD("Register custom parser adapter for op %s of fmk type %s success.", reg_data.GetOmOptype().c_str(), | |||||
| GELOGD("Register custom parser adapter for op %s of fmk type %s success.", | |||||
| reg_data.GetOmOptype().c_str(), | |||||
| TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()).c_str()); | TypeUtils::FmkTypeToSerialString(reg_data.GetFrameworkType()).c_str()); | ||||
| } | } | ||||
| return true; | return true; | ||||
| @@ -1,212 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tbe_plugin_loader.h" | |||||
| #include <dirent.h> | |||||
| #include <sys/stat.h> | |||||
| #include <unistd.h> | |||||
| #include <algorithm> | |||||
| #include <cstring> | |||||
| #include <fstream> | |||||
| #include <iostream> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "common/util/error_manager/error_manager.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/string_util.h" | |||||
| #include "framework/omg/parser/parser_inner_ctx.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| namespace ge { | |||||
| std::map<string, string> TBEPluginLoader::options_ = {}; | |||||
| namespace { | |||||
| const std::string FRAMEWORK_TYPE = "ge.frameworkType"; | |||||
| } | |||||
| // Get Singleton Instance | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY TBEPluginLoader &TBEPluginLoader::Instance() { | |||||
| static TBEPluginLoader instance_ptr_; | |||||
| return instance_ptr_; | |||||
| } | |||||
| Status TBEPluginLoader::ClearHandles_() { | |||||
| Status ret = SUCCESS; | |||||
| for (const auto &handle : handles_vec_) { | |||||
| if (dlclose(handle) != 0) { | |||||
| ret = FAILED; | |||||
| GELOGW("Failed to close handle: %s", dlerror()); | |||||
| } | |||||
| } | |||||
| handles_vec_.clear(); | |||||
| return ret; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status TBEPluginLoader::Finalize() { | |||||
| Status ret = ClearHandles_(); | |||||
| return ret; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginLoader::LoadPluginSo( | |||||
| const std::map<string, string> &options) { | |||||
| vector<string> file_list; | |||||
| string caffe_parser_path; | |||||
| std::string plugin_path; | |||||
| options_ = options; | |||||
| GetCustomOpPath(plugin_path); | |||||
| // Whether there are files in the plugin so path | |||||
| GetPluginSoFileList(plugin_path, file_list, caffe_parser_path); | |||||
| // No file | |||||
| if (file_list.empty()) { | |||||
| // Print log | |||||
| GELOGW("Can not find any plugin file in plugin_path: %s", plugin_path.c_str()); | |||||
| } | |||||
| GELOGW("The shared library will not be checked. Please ensure that the source of the shared library is trusted."); | |||||
| // Load other so files except lib_caffe_parser.so in the plugin so path | |||||
| for (auto elem : file_list) { | |||||
| StringUtils::Trim(elem); | |||||
| void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL | RTLD_NODELETE); | |||||
| if (handle == nullptr) { | |||||
| GELOGW("dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror()); | |||||
| } else if (find(handles_vec_.begin(), handles_vec_.end(), handle) == handles_vec_.end()) { | |||||
| // Close dl when the program exist, not close here | |||||
| GELOGI("Plugin load %s success.", elem.c_str()); | |||||
| handles_vec_.push_back(handle); | |||||
| } else { | |||||
| GELOGI("Plugin so has already been loaded, no need to load again."); | |||||
| } | |||||
| } | |||||
| } | |||||
| void TBEPluginLoader::GetCustomOpPath(std::string &customop_path) { | |||||
| GELOGI("Enter get custom op path schedule"); | |||||
| std::string fmk_type; | |||||
| domi::FrameworkType type = domi::TENSORFLOW; | |||||
| auto it = options_.find(FRAMEWORK_TYPE); | |||||
| if (it != options_.end()) { | |||||
| type = static_cast<domi::FrameworkType>(std::strtol(it->second.c_str(), nullptr, 10)); | |||||
| } | |||||
| fmk_type = ge::TypeUtils::FmkTypeToSerialString(type); | |||||
| GELOGI("Framework type is %s.", fmk_type.c_str()); | |||||
| const char *path_env = std::getenv("ASCEND_OPP_PATH"); | |||||
| if (path_env != nullptr) { | |||||
| std::string path = path_env; | |||||
| customop_path = (path + "/framework/custom" + "/:") + (path + "/framework/built-in/" + fmk_type); | |||||
| GELOGI("Get custom so path from env : %s", path_env); | |||||
| return; | |||||
| } | |||||
| std::string path_base = GetPath(); | |||||
| GELOGI("path_base is %s", path_base.c_str()); | |||||
| path_base = path_base.substr(0, path_base.rfind('/')); | |||||
| path_base = path_base.substr(0, path_base.rfind('/') + 1); | |||||
| customop_path = (path_base + "ops/framework/custom" + "/:") + (path_base + "ops/framework/built-in/" + fmk_type); | |||||
| } | |||||
| string TBEPluginLoader::GetPath() { | |||||
| Dl_info dl_info; | |||||
| if (dladdr(reinterpret_cast<void *>(&TBEPluginLoader::GetPath), &dl_info) == 0) { | |||||
| GELOGW("Failed to read so path!"); | |||||
| return string(); | |||||
| } else { | |||||
| string so_path = dl_info.dli_fname; | |||||
| char path[PATH_MAX] = {0}; | |||||
| if (so_path.length() >= PATH_MAX) { | |||||
| GELOGW("File path is too long!"); | |||||
| return string(); | |||||
| } | |||||
| if (realpath(so_path.c_str(), path) == nullptr) { | |||||
| GELOGW("Failed to get realpath of %s", so_path.c_str()); | |||||
| return string(); | |||||
| } | |||||
| so_path = path; | |||||
| so_path = so_path.substr(0, so_path.rfind('/') + 1); | |||||
| return so_path; | |||||
| } | |||||
| } | |||||
| void TBEPluginLoader::GetPluginSoFileList(const string &path, vector<string> &file_list, string &caffe_parser_path) { | |||||
| // Support to split multiple so directories by ":" | |||||
| vector<string> v_path = StringUtils::Split(path, ':'); | |||||
| for (size_t i = 0; i < v_path.size(); ++i) { | |||||
| FindParserSo(v_path[i], file_list, caffe_parser_path); | |||||
| GELOGI("CustomOpLib full name = %s", v_path[i].c_str()); | |||||
| } | |||||
| } | |||||
| void TBEPluginLoader::FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path) { | |||||
| // Path, change to absolute path | |||||
| string real_path = ge::parser::RealPath(path.c_str()); | |||||
| // Plugin path does not exist | |||||
| if (real_path.empty()) { | |||||
| GELOGW("RealPath is empty."); | |||||
| return; | |||||
| } | |||||
| struct stat stat_buf; | |||||
| if ((stat(real_path.c_str(), &stat_buf) != 0) || (!S_ISDIR(stat_buf.st_mode))) { | |||||
| GELOGW("%s is not a dir.", real_path.c_str()); | |||||
| return; | |||||
| } | |||||
| struct dirent *dent(0); | |||||
| DIR *dir = opendir(real_path.c_str()); | |||||
| // Plugin path does not exist | |||||
| if (dir == nullptr) { | |||||
| GELOGW("Open directory %s failed.", real_path.c_str()); | |||||
| return; | |||||
| } | |||||
| while ((dent = readdir(dir)) != nullptr) { | |||||
| if (strcmp(dent->d_name, ".") == 0 || strcmp(dent->d_name, "..") == 0) continue; | |||||
| string name = dent->d_name; | |||||
| string full_name = real_path + "/" + name; | |||||
| const string so_suff = ".so"; | |||||
| const string caffe_parser_so_suff = "lib_caffe_parser.so"; | |||||
| const string aicpu_so_suff = "_aicpu.so"; | |||||
| const string aicpu_host_so_suff = "_online.so"; | |||||
| if (name.size() >= so_suff.size() && name.compare(name.size() - so_suff.size(), so_suff.size(), so_suff) == 0) { | |||||
| ProcessSoFullName(file_list, caffe_parser_path, full_name, caffe_parser_so_suff, aicpu_so_suff, | |||||
| aicpu_host_so_suff); | |||||
| } else { | |||||
| FindParserSo(full_name, file_list, caffe_parser_path); | |||||
| } | |||||
| } | |||||
| closedir(dir); | |||||
| } | |||||
| void TBEPluginLoader::ProcessSoFullName(vector<string> &file_list, string &caffe_parser_path, string &full_name, | |||||
| const string &caffe_parser_so_suff, const string &aicpu_so_suff, | |||||
| const string &aicpu_host_so_suff) { | |||||
| if (full_name.size() >= caffe_parser_so_suff.size() && | |||||
| full_name.compare(full_name.size() - caffe_parser_so_suff.size(), caffe_parser_so_suff.size(), | |||||
| caffe_parser_so_suff) == 0) { | |||||
| caffe_parser_path = full_name; | |||||
| } else { | |||||
| // Save parser so path into file_list vector | |||||
| file_list.push_back(full_name); | |||||
| } | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,62 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_COMMON_TBE_PLUGIN_LOADER_H_ | |||||
| #define PARSER_COMMON_TBE_PLUGIN_LOADER_H_ | |||||
| #include <dlfcn.h> | |||||
| #include <functional> | |||||
| #include <iostream> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <type_traits> | |||||
| #include <typeinfo> | |||||
| #include <vector> | |||||
| #include "external/ge/ge_api_error_codes.h" | |||||
| #include "external/register/register.h" | |||||
| namespace ge { | |||||
| using SoHandlesVec = std::vector<void *>; | |||||
| class TBEPluginLoader { | |||||
| public: | |||||
| Status Finalize(); | |||||
| // Get TBEPluginManager singleton instance | |||||
| static TBEPluginLoader& Instance(); | |||||
| void LoadPluginSo(const std::map<string, string> &options); | |||||
| static string GetPath(); | |||||
| private: | |||||
| TBEPluginLoader() = default; | |||||
| ~TBEPluginLoader() = default; | |||||
| Status ClearHandles_(); | |||||
| static void ProcessSoFullName(vector<string> &file_list, string &caffe_parser_path, string &full_name, | |||||
| const string &caffe_parser_so_suff, const string &aicpu_so_suff, | |||||
| const string &aicpu_host_so_suff); | |||||
| static void GetCustomOpPath(std::string &customop_path); | |||||
| static void GetPluginSoFileList(const string &path, vector<string> &file_list, string &caffe_parser_path); | |||||
| static void FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path); | |||||
| SoHandlesVec handles_vec_; | |||||
| static std::map<string, string> options_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif //PARSER_COMMON_TBE_PLUGIN_LOADER_H_ | |||||
| @@ -1,78 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/thread_pool.h" | |||||
| #include <atomic> | |||||
| #include <functional> | |||||
| #include <queue> | |||||
| #include <stdexcept> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "register/register_types.h" | |||||
| namespace ge { | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ThreadPool::ThreadPool(uint32_t size) : is_stoped_(false) { | |||||
| idle_thrd_num_ = size < 1 ? 1 : size; | |||||
| for (uint32_t i = 0; i < idle_thrd_num_; ++i) { | |||||
| pool_.emplace_back(ThreadFunc, this); | |||||
| } | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ThreadPool::~ThreadPool() { | |||||
| is_stoped_.store(true); | |||||
| { | |||||
| std::unique_lock<std::mutex> lock{m_lock_}; | |||||
| cond_var_.notify_all(); | |||||
| } | |||||
| for (std::thread &thd : pool_) { | |||||
| if (thd.joinable()) { | |||||
| try { | |||||
| thd.join(); | |||||
| } catch (const std::system_error &) { | |||||
| GELOGW("system_error"); | |||||
| } catch (...) { | |||||
| GELOGW("exception"); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| void ThreadPool::ThreadFunc(ThreadPool *thread_pool) { | |||||
| if (thread_pool == nullptr) { | |||||
| return; | |||||
| } | |||||
| while (!thread_pool->is_stoped_) { | |||||
| std::function<void()> task; | |||||
| { | |||||
| std::unique_lock<std::mutex> lock{thread_pool->m_lock_}; | |||||
| thread_pool->cond_var_.wait( | |||||
| lock, [thread_pool] { return thread_pool->is_stoped_.load() || !thread_pool->tasks_.empty(); }); | |||||
| if (thread_pool->is_stoped_ && thread_pool->tasks_.empty()) { | |||||
| return; | |||||
| } | |||||
| task = std::move(thread_pool->tasks_.front()); | |||||
| thread_pool->tasks_.pop(); | |||||
| } | |||||
| --thread_pool->idle_thrd_num_; | |||||
| task(); | |||||
| ++thread_pool->idle_thrd_num_; | |||||
| } | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,83 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PARSER_COMMON_THREAD_POOL_H_ | |||||
| #define PARSER_COMMON_THREAD_POOL_H_ | |||||
| #include <atomic> | |||||
| #include <condition_variable> | |||||
| #include <functional> | |||||
| #include <future> | |||||
| #include <memory> | |||||
| #include <queue> | |||||
| #include <stdexcept> | |||||
| #include <thread> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/ge_inner_error_codes.h" | |||||
| #include "external/ge/ge_api_error_codes.h" | |||||
| #include "graph/types.h" | |||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| namespace ge { | |||||
| using ThreadTask = std::function<void()>; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ThreadPool { | |||||
| public: | |||||
| explicit ThreadPool(uint32_t size = 4); | |||||
| ~ThreadPool(); | |||||
| template <class Func, class... Args> | |||||
| auto commit(Func &&func, Args &&... args) -> std::future<decltype(func(args...))> { | |||||
| GELOGD("commit run task enter."); | |||||
| using retType = decltype(func(args...)); | |||||
| std::future<retType> fail_future; | |||||
| if (is_stoped_.load()) { | |||||
| GELOGE(ge::FAILED, "thread pool has been stopped."); | |||||
| return fail_future; | |||||
| } | |||||
| auto bindFunc = std::bind(std::forward<Func>(func), std::forward<Args>(args)...); | |||||
| auto task = ge::parser::MakeShared<std::packaged_task<retType()>>(bindFunc); | |||||
| if (task == nullptr) { | |||||
| GELOGE(ge::FAILED, "Make shared failed."); | |||||
| return fail_future; | |||||
| } | |||||
| std::future<retType> future = task->get_future(); | |||||
| { | |||||
| std::lock_guard<std::mutex> lock{m_lock_}; | |||||
| tasks_.emplace([task]() { (*task)(); }); | |||||
| } | |||||
| cond_var_.notify_one(); | |||||
| GELOGD("commit run task end"); | |||||
| return future; | |||||
| } | |||||
| static void ThreadFunc(ThreadPool *thread_pool); | |||||
| private: | |||||
| std::vector<std::thread> pool_; | |||||
| std::queue<ThreadTask> tasks_; | |||||
| std::mutex m_lock_; | |||||
| std::condition_variable cond_var_; | |||||
| std::atomic<bool> is_stoped_; | |||||
| std::atomic<uint32_t> idle_thrd_num_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // PARSER_COMMON_THREAD_POOL_H_ | |||||
| @@ -2,7 +2,7 @@ include $(BUILD_SYSTEM)/base_rules.mk | |||||
| FUNCTION_TO_GRAPH_OUT_TIMESTAMP := $(HOST_OUT_ROOT)/func_to_graph/.timestamp | FUNCTION_TO_GRAPH_OUT_TIMESTAMP := $(HOST_OUT_ROOT)/func_to_graph/.timestamp | ||||
| PROTO_SRC_DIR = parser/parser/func_to_graph/proto | |||||
| PROTO_SRC_DIR = framework/domi/parser/func_to_graph/proto | |||||
| PY_PROTO_BUILD_DIR = $(HOST_OUT_ROOT)/tmp/function_to_graph/proto | PY_PROTO_BUILD_DIR = $(HOST_OUT_ROOT)/tmp/function_to_graph/proto | ||||
| $(warning PRIVATE_PROTOC is $(PRIVATE_PROTOC)) | $(warning PRIVATE_PROTOC is $(PRIVATE_PROTOC)) | ||||
| @@ -14,4 +14,4 @@ $(FUNCTION_TO_GRAPH_OUT_TIMESTAMP): $(PRIVATE_PROTOC) | |||||
| $(LOCAL_BUILT_MODULE): $(FUNCTION_TO_GRAPH_OUT_TIMESTAMP) | $(LOCAL_BUILT_MODULE): $(FUNCTION_TO_GRAPH_OUT_TIMESTAMP) | ||||
| mkdir -p $@ | mkdir -p $@ | ||||
| cp -rf $(PY_PROTO_BUILD_DIR)/* $@ | |||||
| cp -rf $(PY_PROTO_BUILD_DIR)/* $@ | |||||
| @@ -1,6 +1,6 @@ | |||||
| LOCAL_PATH := $(call my-dir) | LOCAL_PATH := $(call my-dir) | ||||
| include $(LOCAL_PATH)/stub/Makefile | |||||
| include $(LOCAL_PATH)/../stub/Makefile | |||||
| COMMON_LOCAL_C_INCLUDES := \ | COMMON_LOCAL_C_INCLUDES := \ | ||||
| proto/om.proto \ | proto/om.proto \ | ||||
| proto/insert_op.proto \ | proto/insert_op.proto \ | ||||
| @@ -39,9 +39,7 @@ COMMON_LOCAL_C_INCLUDES := \ | |||||
| $(TOPDIR)inc/external/graph \ | $(TOPDIR)inc/external/graph \ | ||||
| $(TOPDIR)inc/external/parser \ | $(TOPDIR)inc/external/parser \ | ||||
| $(TOPDIR)inc/framework \ | $(TOPDIR)inc/framework \ | ||||
| $(TOPDIR)parser/parser \ | |||||
| $(TOPDIR)parser \ | |||||
| $(TOPDIR)graphengine/ge \ | |||||
| $(TOPDIR)framework/domi/parser \ | |||||
| libc_sec/include \ | libc_sec/include \ | ||||
| third_party/protobuf/include \ | third_party/protobuf/include \ | ||||
| third_party/json/include \ | third_party/json/include \ | ||||
| @@ -115,6 +113,7 @@ LOCAL_SHARED_LIBRARIES := \ | |||||
| libparser_common \ | libparser_common \ | ||||
| libgraph \ | libgraph \ | ||||
| libregister \ | libregister \ | ||||
| libge_common \ | |||||
| lib_caffe_parser \ | lib_caffe_parser \ | ||||
| LOCAL_LDFLAGS := -lrt | LOCAL_LDFLAGS := -lrt | ||||
| @@ -134,8 +133,8 @@ endif | |||||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | ||||
| LOCAL_SRC_FILES := ../../out/parser/lib64/stub/tensorflow_parser.cc | |||||
| LOCAL_SRC_FILES += ../../out/parser/lib64/stub/caffe_parser.cc | |||||
| LOCAL_SRC_FILES := ../../../out/ge/lib64/stub/tensorflow_parser.cc | |||||
| LOCAL_SRC_FILES += ../../../out/ge/lib64/stub/caffe_parser.cc | |||||
| LOCAL_SHARED_LIBRARIES := | LOCAL_SHARED_LIBRARIES := | ||||
| @@ -29,9 +29,7 @@ LOCAL_C_INCLUDES := \ | |||||
| $(TOPDIR)inc/external \ | $(TOPDIR)inc/external \ | ||||
| $(TOPDIR)inc/external/graph \ | $(TOPDIR)inc/external/graph \ | ||||
| $(TOPDIR)inc/framework \ | $(TOPDIR)inc/framework \ | ||||
| $(TOPDIR)parser \ | |||||
| $(TOPDIR)parser/parser \ | |||||
| $(TOPDIR)graphengine/ge \ | |||||
| $(TOPDIR)framework/domi/parser \ | |||||
| libc_sec/include \ | libc_sec/include \ | ||||
| third_party/protobuf/include \ | third_party/protobuf/include \ | ||||
| third_party/json/include \ | third_party/json/include \ | ||||
| @@ -45,6 +43,7 @@ LOCAL_SHARED_LIBRARIES := \ | |||||
| libparser_common \ | libparser_common \ | ||||
| libgraph \ | libgraph \ | ||||
| libregister \ | libregister \ | ||||
| libge_common \ | |||||
| LOCAL_LDFLAGS := -lrt | LOCAL_LDFLAGS := -lrt | ||||
| @@ -17,7 +17,7 @@ | |||||
| #include "onnx_constant_parser.h" | #include "onnx_constant_parser.h" | ||||
| #include <map> | #include <map> | ||||
| #include <vector> | #include <vector> | ||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "framework/omg/parser/parser_inner_ctx.h" | #include "framework/omg/parser/parser_inner_ctx.h" | ||||
| #include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
| @@ -30,7 +30,6 @@ using ge::onnx::TensorProto; | |||||
| using domi::ONNX; | using domi::ONNX; | ||||
| using GeShape = ge::GeShape; | using GeShape = ge::GeShape; | ||||
| using GeTensorDesc = ge::GeTensorDesc; | using GeTensorDesc = ge::GeTensorDesc; | ||||
| using namespace ge::parser; | |||||
| namespace ge { | namespace ge { | ||||
| Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, int count) { | Status OnnxConstantParser::ParseConvertData(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, int count) { | ||||
| @@ -22,7 +22,6 @@ | |||||
| #include "parser/onnx/onnx_util.h" | #include "parser/onnx/onnx_util.h" | ||||
| using domi::ONNX; | using domi::ONNX; | ||||
| using namespace ge::parser; | |||||
| namespace ge { | namespace ge { | ||||
| Status OnnxDataParser::ParseParams(const Message *op_src, ge::Operator &op_def) { | Status OnnxDataParser::ParseParams(const Message *op_src, ge::Operator &op_def) { | ||||
| @@ -18,25 +18,24 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include "common/convert/pb2json.h" | #include "common/convert/pb2json.h" | ||||
| #include "common/model_saver.h" | |||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "external/graph/operator_factory.h" | #include "external/graph/operator_factory.h" | ||||
| #include "external/register/register_error_codes.h" | #include "external/register/register_error_codes.h" | ||||
| #include "framework/omg/parser/parser_inner_ctx.h" | #include "framework/omg/parser/parser_inner_ctx.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "omg/parser/parser_factory.h" | #include "omg/parser/parser_factory.h" | ||||
| #include "onnx_op_parser.h" | #include "onnx_op_parser.h" | ||||
| #include "onnx_util.h" | #include "onnx_util.h" | ||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "parser/common/pre_checker.h" | #include "parser/common/pre_checker.h" | ||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include "parser/common/model_saver.h" | |||||
| #include "parser/common/parser_utils.h" | |||||
| #include "parser/onnx/onnx_util.h" | #include "parser/onnx/onnx_util.h" | ||||
| #include "register/op_registry.h" | #include "register/op_registry.h" | ||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| std::map<std::string, std::string> kOnnxOpMap = { | std::map<std::string, std::string> kOnnxOpMap = { | ||||
| {ge::kOpTypeInput, ge::parser::DATA}, {ge::kOpTypeConstant, ge::parser::CONSTANT}, | |||||
| {ge::kOpTypeInput, ge::DATA}, {ge::kOpTypeConstant, ge::CONSTANT}, | |||||
| }; | }; | ||||
| } | } | ||||
| @@ -256,9 +255,11 @@ Status OnnxModelParser::SetOperatorInputs() { | |||||
| for (auto in_iter = inputs_map_.begin(); in_iter != inputs_map_.end(); in_iter++) { | for (auto in_iter = inputs_map_.begin(); in_iter != inputs_map_.end(); in_iter++) { | ||||
| auto out_iter = outputs_map_.find(in_iter->first); | auto out_iter = outputs_map_.find(in_iter->first); | ||||
| if (out_iter == outputs_map_.end()) { | if (out_iter == outputs_map_.end()) { | ||||
| GELOGE(INTERNAL_ERROR, "Unknown input: %s:%d in node: %s", in_iter->first.c_str(), in_iter->second[0].second, | |||||
| GELOGW("Unknown input: %s:%d for node: %s, which maybe option input.", | |||||
| in_iter->first.c_str(), | |||||
| in_iter->second[0].second, | |||||
| in_iter->second[0].first.c_str()); | in_iter->second[0].first.c_str()); | ||||
| return INTERNAL_ERROR; | |||||
| continue; | |||||
| } | } | ||||
| std::vector<std::pair<std::string, int>> &input_node_indexs = in_iter->second; | std::vector<std::pair<std::string, int>> &input_node_indexs = in_iter->second; | ||||
| @@ -438,7 +439,7 @@ Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) { | |||||
| // 1. Get graph from onnx model file. | // 1. Get graph from onnx model file. | ||||
| ge::onnx::ModelProto onnx_model; | ge::onnx::ModelProto onnx_model; | ||||
| if (!ge::parser::ReadProtoFromBinaryFile(file, &onnx_model)) { | |||||
| if (!ge::ReadProtoFromBinaryFile(file, &onnx_model)) { | |||||
| GELOGE(PARAM_INVALID, "Read onnx model file failed."); | GELOGE(PARAM_INVALID, "Read onnx model file failed."); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -535,6 +536,8 @@ Status OnnxModelParser::Parse(const char *file, ge::Graph &graph) { | |||||
| } | } | ||||
| graph.SetInputs(input_ops).SetOutputs(output_indexs); | graph.SetInputs(input_ops).SetOutputs(output_indexs); | ||||
| GE_RETURN_IF_ERROR(ParserUtils::ExpandOneToManyGraph(graph)); | |||||
| UpdateFormat(graph); | UpdateFormat(graph); | ||||
| GELOGI("Onnx model parser success."); | GELOGI("Onnx model parser success."); | ||||
| @@ -552,12 +555,12 @@ Status OnnxModelParser::ToJson(const char *model_file, const char *json_file) { | |||||
| } | } | ||||
| ge::onnx::ModelProto onnx_model; | ge::onnx::ModelProto onnx_model; | ||||
| GE_RETURN_WITH_LOG_IF_FALSE(ge::parser::ReadProtoFromBinaryFile(model_file, &onnx_model), | |||||
| GE_RETURN_WITH_LOG_IF_FALSE(ge::ReadProtoFromBinaryFile(model_file, &onnx_model), | |||||
| "ReadProtoFromBinaryFile failed, file:%s.", model_file); | "ReadProtoFromBinaryFile failed, file:%s.", model_file); | ||||
| ge::onnx::GraphProto graph_proto = onnx_model.graph(); | ge::onnx::GraphProto graph_proto = onnx_model.graph(); | ||||
| nlohmann::json j; | nlohmann::json j; | ||||
| ge::Pb2Json::Message2Json(graph_proto, std::set<std::string>(), j, true); | ge::Pb2Json::Message2Json(graph_proto, std::set<std::string>(), j, true); | ||||
| return ge::parser::ModelSaver::SaveJsonToFile(json_file, j); | |||||
| return ge::ModelSaver::SaveJsonToFile(json_file, j); | |||||
| } | } | ||||
| ge::DataType OnnxModelParser::ConvertToGeDataType(const uint32_t type) { | ge::DataType OnnxModelParser::ConvertToGeDataType(const uint32_t type) { | ||||
| @@ -45,9 +45,6 @@ message AippOpParams { | |||||
| // 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。 | // 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。 | ||||
| uint32 related_input_rank = 2; | uint32 related_input_rank = 2; | ||||
| // related_input_name is optional and the top name of data node which inserts aipp | |||||
| string related_input_name = 6; | |||||
| // input_edge_idx参数为可选,类型为整型,配置范围为>=0。 | // input_edge_idx参数为可选,类型为整型,配置范围为>=0。 | ||||
| // 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。 | // 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。 | ||||
| // 配置值 <= Data算子输出边的个数。 | // 配置值 <= Data算子输出边的个数。 | ||||
| @@ -18,8 +18,7 @@ | |||||
| #include <iostream> | #include <iostream> | ||||
| #include "common/fmk_error_codes.h" | #include "common/fmk_error_codes.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include "common/types.h" | |||||
| #include "common/types_map.h" | #include "common/types_map.h" | ||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
| @@ -219,7 +218,7 @@ domi::Status GraphToFunctionDef::RecordResult(ge::ComputeGraphPtr graph, | |||||
| string op_name = anchor->GetOwnerNode()->GetName() + "_" + to_string(anchor->GetIdx()) + "_retval"; | string op_name = anchor->GetOwnerNode()->GetName() + "_" + to_string(anchor->GetIdx()) + "_retval"; | ||||
| ge::OpDescPtr op = nullptr; | ge::OpDescPtr op = nullptr; | ||||
| GE_MAKE_SHARED(op = std::make_shared<ge::OpDesc>(op_name, ge::parser::NETOUTPUT), return FAILED); | |||||
| GE_MAKE_SHARED(op = std::make_shared<ge::OpDesc>(op_name, ge::NETOUTPUT), return FAILED); | |||||
| graphStatus status = op->AddInputDesc(ge::GeTensorDesc()); | graphStatus status = op->AddInputDesc(ge::GeTensorDesc()); | ||||
| if (status != GRAPH_SUCCESS) { | if (status != GRAPH_SUCCESS) { | ||||
| GELOGE(FAILED, "Add input desc for op:%s failed.", op->GetName().c_str()); | GELOGE(FAILED, "Add input desc for op:%s failed.", op->GetName().c_str()); | ||||
| @@ -282,7 +281,7 @@ domi::Status GraphToFunctionDef::RecordArg(ge::ComputeGraphPtr graph, const vect | |||||
| string op_name = anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName() + "_" + | string op_name = anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName() + "_" + | ||||
| to_string(anchor->GetPeerOutAnchor()->GetIdx()) + "_arg"; | to_string(anchor->GetPeerOutAnchor()->GetIdx()) + "_arg"; | ||||
| ge::OpDescPtr op = nullptr; | ge::OpDescPtr op = nullptr; | ||||
| GE_MAKE_SHARED(op = std::make_shared<ge::OpDesc>(op_name, ge::parser::DATA), return FAILED); | |||||
| GE_MAKE_SHARED(op = std::make_shared<ge::OpDesc>(op_name, ge::DATA), return FAILED); | |||||
| graphStatus status = op->AddOutputDesc(ge::GeTensorDesc()); | graphStatus status = op->AddOutputDesc(ge::GeTensorDesc()); | ||||
| if (status != GRAPH_SUCCESS) { | if (status != GRAPH_SUCCESS) { | ||||
| GELOGE(FAILED, "Add output desc for op:%s failed.", op->GetName().c_str()); | GELOGE(FAILED, "Add output desc for op:%s failed.", op->GetName().c_str()); | ||||
| @@ -330,7 +329,7 @@ domi::Status GraphToFunctionDef::DavGraphToFunctionDef(ge::ComputeGraphPtr graph | |||||
| for (const ge::NodePtr &node : graph->GetDirectNode()) { | for (const ge::NodePtr &node : graph->GetDirectNode()) { | ||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| if (node->GetOpDesc()->GetType() == ge::parser::DATA) { | |||||
| if (node->GetOpDesc()->GetType() == ge::DATA) { | |||||
| int64_t index = 0; | int64_t index = 0; | ||||
| int64_t type = 1; | int64_t type = 1; | ||||
| @@ -351,7 +350,7 @@ domi::Status GraphToFunctionDef::DavGraphToFunctionDef(ge::ComputeGraphPtr graph | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (node->GetOpDesc()->GetType() == ge::parser::NETOUTPUT) { | |||||
| if (node->GetOpDesc()->GetType() == ge::NETOUTPUT) { | |||||
| int64_t index = 0; | int64_t index = 0; | ||||
| int64_t type = 1; | int64_t type = 1; | ||||
| @@ -475,7 +474,7 @@ domi::Status GraphToFunctionDef::BuildFunctionDef(ge::ComputeGraphPtr &graph, co | |||||
| GE_CHECK_NOTNULL(library); | GE_CHECK_NOTNULL(library); | ||||
| GE_CHECK_NOTNULL(call_node_def); | GE_CHECK_NOTNULL(call_node_def); | ||||
| // Current date / time base on the current system | // Current date / time base on the current system | ||||
| string now_time = ge::parser::CurrentTimeInStr(); | |||||
| string now_time = ge::CurrentTimeInStr(); | |||||
| static int i = 0; | static int i = 0; | ||||
| const string name = name_in + now_time + to_string(i); | const string name = name_in + now_time + to_string(i); | ||||
| i++; | i++; | ||||
| @@ -21,7 +21,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "common/fmk_types.h" | #include "common/fmk_types.h" | ||||
| #include "common/op/ge_op_utils.h" | #include "common/op/ge_op_utils.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "common/types.h" | |||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| #include "graph/node.h" | #include "graph/node.h" | ||||
| #include "graph/types.h" | #include "graph/types.h" | ||||
| @@ -23,13 +23,13 @@ | |||||
| #include "cce/cce.h" | #include "cce/cce.h" | ||||
| #include "cce/dnn.h" | #include "cce/dnn.h" | ||||
| #include "common/debug/log.h" | #include "common/debug/log.h" | ||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include "common/math/math_util.h" | |||||
| #include "common/op/ge_op_utils.h" | #include "common/op/ge_op_utils.h" | ||||
| #include "common/op_map.h" | #include "common/op_map.h" | ||||
| #include "common/types.h" | |||||
| #include "common/types_map.h" | #include "common/types_map.h" | ||||
| #include "common/util.h" | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/omg/parser/parser_inner_ctx.h" | |||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
| @@ -39,7 +39,6 @@ | |||||
| #include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "graph_functiondef.h" | #include "graph_functiondef.h" | ||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include "proto/tensorflow/attr_value.pb.h" | #include "proto/tensorflow/attr_value.pb.h" | ||||
| #include "register/op_registry.h" | #include "register/op_registry.h" | ||||
| @@ -92,137 +91,117 @@ const char RRTVAL_NODE_NAME_SUFFIX[] = "_RetVal"; | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::map<string, OpSupportTranInfo> g_OpSupportTranInfo = {}; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::map<string, OpSupportTranInfo> g_OpSupportTranInfo = {}; | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::CAST, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::CAST, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, | |||||
| OutDtSupportUndefined) | OutDtSupportUndefined) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::CAST, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::CAST, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, | |||||
| OutDtSupportUndefined) | OutDtSupportUndefined) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::ADDN, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::ADDN, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | OutDtSupportAsInput) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::ADDN, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::ADDN, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | OutDtSupportAsInput) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::ADD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::ADD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | OutDtSupportAsInput) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::ADD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::ADD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | OutDtSupportAsInput) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::MUL, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::MUL, | |||||
| std::vector<ge::Format>({ge::FORMAT_FRACTAL_Z, ge::FORMAT_NCHW, ge::FORMAT_NHWC, | std::vector<ge::Format>({ge::FORMAT_FRACTAL_Z, ge::FORMAT_NCHW, ge::FORMAT_NHWC, | ||||
| ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0}), | ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0}), | ||||
| InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) | InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::L2LOSS, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::L2LOSS, | |||||
| std::vector<ge::Format>({ge::FORMAT_FRACTAL_Z, ge::FORMAT_NC1HWC0, ge::FORMAT_NHWC, | std::vector<ge::Format>({ge::FORMAT_FRACTAL_Z, ge::FORMAT_NC1HWC0, ge::FORMAT_NHWC, | ||||
| ge::FORMAT_HWCN}), // inputformats | ge::FORMAT_HWCN}), // inputformats | ||||
| ge::DT_FLOAT, ge::FORMAT_NC1HWC0, ge::DT_FLOAT) | ge::DT_FLOAT, ge::FORMAT_NC1HWC0, ge::DT_FLOAT) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::CONVGRADFILTER, InFmtSupportUndefined, InDtSupportUndefined, | |||||
| ge::FORMAT_FRACTAL_Z, ge::DT_FLOAT) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::CONV2DBACKPROPINPUT, InFmtSupportUndefined, InDtSupportUndefined, | |||||
| ge::FORMAT_NC1HWC0, ge::DT_FLOAT16) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::BIASADDGRAD, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::CONVGRADFILTER, InFmtSupportUndefined, InDtSupportUndefined, ge::FORMAT_FRACTAL_Z, | |||||
| ge::DT_FLOAT) | ge::DT_FLOAT) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::BIASADD, ge::FORMAT_NCHW, ge::DT_FLOAT, ge::FORMAT_NCHW, ge::DT_FLOAT) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::CONV2DBACKPROPINPUT, InFmtSupportUndefined, InDtSupportUndefined, | |||||
| ge::FORMAT_NC1HWC0, ge::DT_FLOAT16) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::BIASADDGRAD, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, ge::DT_FLOAT) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::BIASADD, ge::FORMAT_NCHW, ge::DT_FLOAT, ge::FORMAT_NCHW, ge::DT_FLOAT) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::ACTIVATION, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, | |||||
| ge::DT_FLOAT16) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::ACTIVATIONGRAD, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, | |||||
| ge::DT_FLOAT16) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::SOFTMAX, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::ACTIVATION, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::ACTIVATIONGRAD, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, | |||||
| ge::DT_FLOAT16) | ge::DT_FLOAT16) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SOFTMAX, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::SOFTMAX, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SOFTMAX, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::DEPTHWISECONV2DBACKPROPFILTER, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::DEPTHWISECONV2DBACKPROPFILTER, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16, | |||||
| ge::FORMAT_C1HWNCoC0, ge::DT_FLOAT) | ge::FORMAT_C1HWNCoC0, ge::DT_FLOAT) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::DEPTHWISECONV2DBACKPORPINPUT, InFmtSupportUndefined, InDtSupportUndefined, | |||||
| OutFmtSupportAsInput, OutDtSupportUndefined) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::DEPTHWISECONV2DFORWARDNATIVE, InFmtSupportUndefined, InDtSupportUndefined, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::DEPTHWISECONV2DBACKPORPINPUT, InFmtSupportUndefined, InDtSupportUndefined, | |||||
| OutFmtSupportAsInput, OutDtSupportUndefined) | OutFmtSupportAsInput, OutDtSupportUndefined) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::FUSEDBATCHNORM, InFmtSupportUndefined, InDtSupportUndefined, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::DEPTHWISECONV2DFORWARDNATIVE, InFmtSupportUndefined, InDtSupportUndefined, | |||||
| OutFmtSupportAsInput, OutDtSupportUndefined) | OutFmtSupportAsInput, OutDtSupportUndefined) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::FUSEDBATCHNORMGRAD, InFmtSupportUndefined, InDtSupportUndefined, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::FUSEDBATCHNORM, InFmtSupportUndefined, InDtSupportUndefined, OutFmtSupportAsInput, | |||||
| OutDtSupportUndefined) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::FUSEDBATCHNORMGRAD, InFmtSupportUndefined, InDtSupportUndefined, | |||||
| OutFmtSupportAsInput, OutDtSupportUndefined) | OutFmtSupportAsInput, OutDtSupportUndefined) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::CONV2D, InFmtSupportUndefined, InDtSupportUndefined, OutFmtSupportAsInput, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::CONV2D, InFmtSupportUndefined, InDtSupportUndefined, OutFmtSupportAsInput, | |||||
| OutDtSupportUndefined) | OutDtSupportUndefined) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::RESHAPE, ge::FORMAT_NHWC, InDtSupportAll, ge::FORMAT_NHWC, | |||||
| OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::SPARSESOFTMAXCROSSENTROPYWITHLOGITS, InFmtSupport5D, ge::DT_FLOAT16, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::RESHAPE, ge::FORMAT_NHWC, InDtSupportAll, ge::FORMAT_NHWC, OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::SPARSESOFTMAXCROSSENTROPYWITHLOGITS, InFmtSupport5D, ge::DT_FLOAT16, | |||||
| OutFmtSupportAsInput, OutDtSupportAsInput) | OutFmtSupportAsInput, OutDtSupportAsInput) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TF_MAXIMUM_GRAD, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, | TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TF_MAXIMUM_GRAD, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, | ||||
| OutDtSupportAsInput) | OutDtSupportAsInput) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::APPLYRMSPROP, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::APPLYRMSPROP, | |||||
| std::vector<ge::Format>({ge::FORMAT_FRACTAL_Z, ge::FORMAT_NCHW, ge::FORMAT_NHWC, | std::vector<ge::Format>({ge::FORMAT_FRACTAL_Z, ge::FORMAT_NCHW, ge::FORMAT_NHWC, | ||||
| ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0}), | ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0}), | ||||
| ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) | ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::DROPOUTDOMASK, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::LOG, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SQRTGRAD, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SIGMOIDGRAD, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SIGMOID, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::ARGMAX, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::AVGPOOLGRAD, InFmtSupport5D, ge::DT_FLOAT16, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::NEG, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::RECIPROCAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::DROPOUTDOMASK, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | OutDtSupportAsInput) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SQUARE, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::LOG, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SQRTGRAD, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SIGMOIDGRAD, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | OutDtSupportAsInput) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SUB, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SIGMOID, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::ARGMAX, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::AVGPOOLGRAD, InFmtSupport5D, ge::DT_FLOAT16, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | OutDtSupportAsInput) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SUM, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::NEG, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::RECIPROCAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | OutDtSupportAsInput) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SQUARE, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SUB, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SUM, InFmtSupport4D, InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TF_MATMUL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) | TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TF_MATMUL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::GATHERV2, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::GATHERV2, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::GREATEREQUAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | OutDtSupportAsInput) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::GREATEREQUAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::REALDIV, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SQRT, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::STRIDEDSLICE, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::TILE, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::REALDIV, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SQRT, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::STRIDEDSLICE, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | OutDtSupportAsInput) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TILE, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TFRELU6, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, | TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TFRELU6, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, | ||||
| OutDtSupportAsInput) | OutDtSupportAsInput) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::RELU6GRAD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::EQUAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::GREATER, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::SELECT, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::RELU6GRAD, InFmtSupportElewise, InDtSupportAll, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | OutDtSupportAsInput) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::EQUAL, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::GREATER, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::SELECT, InFmtSupport4D, ge::DT_FLOAT, OutFmtSupportAsInput, OutDtSupportAsInput) | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TF_BATCH_MATMUL, ge::FORMAT_NHWC, InDtSupportAll, OutFmtSupportAsInput, | TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TF_BATCH_MATMUL, ge::FORMAT_NHWC, InDtSupportAll, OutFmtSupportAsInput, | ||||
| OutDtSupportAsInput) | OutDtSupportAsInput) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::parser::TRANSPOSE, ge::FORMAT_NHWC, InDtSupportAll, OutFmtSupportAsInput, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(TBE, ge::TRANSPOSE, ge::FORMAT_NHWC, InDtSupportAll, OutFmtSupportAsInput, | |||||
| OutDtSupportAsInput) | OutDtSupportAsInput) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::STREAMMERGE, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::STREAMMERGE, | |||||
| std::vector<ge::Format>({ge::FORMAT_NCHW, ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0}), | std::vector<ge::Format>({ge::FORMAT_NCHW, ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0}), | ||||
| InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) | InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) | ||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::parser::MEMCPYASYNC, | |||||
| TBE_SET_FORMAT_DATAYPE_INFO(CCE, ge::MEMCPYASYNC, | |||||
| std::vector<ge::Format>({ge::FORMAT_NCHW, ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0}), | std::vector<ge::Format>({ge::FORMAT_NCHW, ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0}), | ||||
| InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) | InDtSupportAll, OutFmtSupportAsInput, OutDtSupportAsInput) | ||||
| bool GetCceTbeTransInfo(string opType, OpSupportTranInfo &opSupportInfo) { | bool GetCceTbeTransInfo(string opType, OpSupportTranInfo &opSupportInfo) { | ||||
| static bool fmtInited = false; | static bool fmtInited = false; | ||||
| GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
| !fmtInited, fmtInited = true; | |||||
| if (domi::OpRegistry().Instance()->GetImplyType(ge::parser::DEPTHWISEWEIGHT4D26D) == domi::ImplyType::TVM) { | |||||
| auto it = g_OpSupportTranInfo.find(string("TBE:") + ge::parser::MUL); | |||||
| if (it != g_OpSupportTranInfo.end()) { | |||||
| auto &fmts = it->second.inputFormats; | |||||
| auto itFmt = std::find(fmts.begin(), fmts.end(), ge::FORMAT_NC1HWC0); | |||||
| fmts.erase(itFmt); | |||||
| } | |||||
| }) | |||||
| !fmtInited, fmtInited = true; | |||||
| if (domi::OpRegistry().Instance()->GetImplyType(ge::DEPTHWISEWEIGHT4D26D) == domi::ImplyType::TVM) { | |||||
| auto it = g_OpSupportTranInfo.find(string("TBE:") + ge::MUL); | |||||
| if (it != g_OpSupportTranInfo.end()) { | |||||
| auto &fmts = it->second.inputFormats; | |||||
| auto itFmt = std::find(fmts.begin(), fmts.end(), ge::FORMAT_NC1HWC0); | |||||
| fmts.erase(itFmt); | |||||
| } | |||||
| }) | |||||
| string cceTbeOpType = "TBE"; | string cceTbeOpType = "TBE"; | ||||
| GE_IF_BOOL_EXEC(domi::OpRegistry().Instance()->GetImplyType(opType) == domi::ImplyType::BUILDIN, | GE_IF_BOOL_EXEC(domi::OpRegistry().Instance()->GetImplyType(opType) == domi::ImplyType::BUILDIN, | ||||
| cceTbeOpType = "CCE";) | cceTbeOpType = "CCE";) | ||||
| @@ -807,7 +786,7 @@ Status CreateNodeDefBytes(ge::NodePtr n, string originalType, map<string, PIOLis | |||||
| for (uint32_t j = 0; j < ge_desc->GetShape().GetDimNum(); ++j) { | for (uint32_t j = 0; j < ge_desc->GetShape().GetDimNum(); ++j) { | ||||
| tmp_dim = ge_desc->GetShape().GetDim(j); | tmp_dim = ge_desc->GetShape().GetDim(j); | ||||
| GE_CHECK_GE(tmp_dim, 0); | GE_CHECK_GE(tmp_dim, 0); | ||||
| PARSER_INT64_MULCHECK(real_size, tmp_dim); | |||||
| FMK_INT64_MULCHECK(real_size, tmp_dim); | |||||
| real_size *= tmp_dim; | real_size *= tmp_dim; | ||||
| } | } | ||||
| ge::TensorUtils::SetSize(*ge_desc, real_size * size_type); | ge::TensorUtils::SetSize(*ge_desc, real_size * size_type); | ||||
| @@ -1198,7 +1177,7 @@ Status CreateFuncDefBytes(ge::NodePtr n, string original_type, string func_bin_p | |||||
| char *buf = nullptr; | char *buf = nullptr; | ||||
| int32_t len = 0; | int32_t len = 0; | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::parser::ReadBytesFromBinaryFile(file.c_str(), &buf, len), return false, | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::ReadBytesFromBinaryFile(file.c_str(), &buf, len), return false, | |||||
| "read bytes file error!"); | "read bytes file error!"); | ||||
| GELOGI("len =%d\n", len); | GELOGI("len =%d\n", len); | ||||
| @@ -1229,7 +1208,7 @@ Status ParserGraphOptimizer::MakeTfProtoDef() { | |||||
| CreateIOListFuncMap(mOpIOListFuncMap); | CreateIOListFuncMap(mOpIOListFuncMap); | ||||
| for (ge::NodePtr n : graph_->GetDirectNode()) { | for (ge::NodePtr n : graph_->GetDirectNode()) { | ||||
| if (n->GetType() != ge::parser::FRAMEWORKOP) continue; | |||||
| if (n->GetType() != ge::FRAMEWORKOP) continue; | |||||
| std::string original_type; | std::string original_type; | ||||
| GE_LOGI_IF(ge::AttrUtils::GetStr(n->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type) != true, | GE_LOGI_IF(ge::AttrUtils::GetStr(n->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type) != true, | ||||
| "get original type failed."); | "get original type failed."); | ||||
| @@ -1290,9 +1269,9 @@ Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr> | |||||
| bool hasGetNext = false; | bool hasGetNext = false; | ||||
| for (auto node : graph_->GetDirectNode()) { | for (auto node : graph_->GetDirectNode()) { | ||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue); | |||||
| GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::FRAMEWORK_OP_TYPE, continue); | |||||
| string type = ""; | string type = ""; | ||||
| GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type)); | |||||
| GE_CHK_STATUS_RET(GetOriginalType(node, type)); | |||||
| if (type == "IteratorGetNext") { | if (type == "IteratorGetNext") { | ||||
| hasGetNext = true; | hasGetNext = true; | ||||
| break; | break; | ||||
| @@ -1300,9 +1279,9 @@ Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr> | |||||
| } | } | ||||
| for (auto node : graph_->GetDirectNode()) { | for (auto node : graph_->GetDirectNode()) { | ||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue) | |||||
| GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::FRAMEWORK_OP_TYPE, continue) | |||||
| string type = ""; | string type = ""; | ||||
| GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type)); | |||||
| GE_CHK_STATUS_RET(GetOriginalType(node, type)); | |||||
| if (type == "IteratorGetNext") { | if (type == "IteratorGetNext") { | ||||
| vector<NodePtr> temp_node_cluser; | vector<NodePtr> temp_node_cluser; | ||||
| for (auto in_anchor : node->GetAllInDataAnchors()) { | for (auto in_anchor : node->GetAllInDataAnchors()) { | ||||
| @@ -1338,9 +1317,9 @@ Status ParserGraphOptimizer::FindFmkNodeCluser(unordered_map<string, vector<Node | |||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| OpDescPtr temp_node_desc_ptr = node->GetOpDesc(); | OpDescPtr temp_node_desc_ptr = node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(temp_node_desc_ptr); | GE_CHECK_NOTNULL(temp_node_desc_ptr); | ||||
| GE_IF_BOOL_EXEC(temp_node_desc_ptr->GetType() == ge::parser::DATA_TYPE, continue); | |||||
| GE_IF_BOOL_EXEC(temp_node_desc_ptr->GetType() == ge::DATA_TYPE, continue); | |||||
| if (temp_node_desc_ptr->GetType() == ge::parser::FRAMEWORK_OP_TYPE && | |||||
| if (temp_node_desc_ptr->GetType() == ge::FRAMEWORK_OP_TYPE && | |||||
| (temp_node_desc_ptr->GetName().find(RRTVAL_NODE_NAME_SUFFIX) == string::npos)) { | (temp_node_desc_ptr->GetName().find(RRTVAL_NODE_NAME_SUFFIX) == string::npos)) { | ||||
| temp_node_cluser.push_back(node); | temp_node_cluser.push_back(node); | ||||
| } else { | } else { | ||||
| @@ -1421,7 +1400,7 @@ Status ParserGraphOptimizer::UpdateGraph(vector<NodePtr> &nodes) { | |||||
| return FAILED); | return FAILED); | ||||
| std::string type = ""; | std::string type = ""; | ||||
| GE_CHK_STATUS_RET(ge::parser::GetOriginalType(nodes[0], type)); | |||||
| GE_CHK_STATUS_RET(GetOriginalType(nodes[0], type)); | |||||
| (void)AttrUtils::SetStr(fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); | (void)AttrUtils::SetStr(fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); | ||||
| (void)AttrUtils::SetZeroCopyBytes( | (void)AttrUtils::SetZeroCopyBytes( | ||||
| @@ -1431,7 +1410,7 @@ Status ParserGraphOptimizer::UpdateGraph(vector<NodePtr> &nodes) { | |||||
| fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_NODE_DEF, | fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_NODE_DEF, | ||||
| Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(nodefStr.data()), nodefStr.length())); | Buffer::CopyFrom(reinterpret_cast<const uint8_t *>(nodefStr.data()), nodefStr.length())); | ||||
| (void)AttrUtils::SetInt(fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, ge::GetParserContext().type); | |||||
| (void)AttrUtils::SetInt(fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, domi::GetContext().type); | |||||
| // reconstruct fusion_node and edges | // reconstruct fusion_node and edges | ||||
| GE_CHK_STATUS_RET(RebuildOutputAnchors(output_anchors, fusion_node_opdef), | GE_CHK_STATUS_RET(RebuildOutputAnchors(output_anchors, fusion_node_opdef), | ||||
| @@ -1481,19 +1460,17 @@ Status ParserGraphOptimizer::InsertNode(ge::ComputeGraphPtr sub_graph, vector<ge | |||||
| } | } | ||||
| InControlAnchorPtr node_in_control = node->GetInControlAnchor(); | InControlAnchorPtr node_in_control = node->GetInControlAnchor(); | ||||
| GE_IF_BOOL_EXEC( | |||||
| node_in_control != nullptr, for (auto peer_out_anchor | |||||
| : node_in_control->GetPeerOutControlAnchors()) { | |||||
| vector<ge::NodePtr>::iterator iter = find(nodes.begin(), nodes.end(), peer_out_anchor->GetOwnerNode()); | |||||
| GE_IF_BOOL_EXEC(iter == nodes.end(), input_control_anchors.emplace_back(node_in_control)); | |||||
| }); | |||||
| GE_IF_BOOL_EXEC(node_in_control != nullptr, for (auto peer_out_anchor | |||||
| : node_in_control->GetPeerOutControlAnchors()) { | |||||
| vector<ge::NodePtr>::iterator iter = find(nodes.begin(), nodes.end(), peer_out_anchor->GetOwnerNode()); | |||||
| GE_IF_BOOL_EXEC(iter == nodes.end(), input_control_anchors.emplace_back(node_in_control)); | |||||
| }); | |||||
| OutControlAnchorPtr node_out_control = node->GetOutControlAnchor(); | OutControlAnchorPtr node_out_control = node->GetOutControlAnchor(); | ||||
| GE_IF_BOOL_EXEC( | |||||
| node_out_control != nullptr, for (auto peer_in_control_anchor | |||||
| : node_out_control->GetPeerInControlAnchors()) { | |||||
| vector<ge::NodePtr>::iterator iter = find(nodes.begin(), nodes.end(), peer_in_control_anchor->GetOwnerNode()); | |||||
| GE_IF_BOOL_EXEC(iter == nodes.end(), output_control_anchors.emplace_back(node_out_control)); | |||||
| }); | |||||
| GE_IF_BOOL_EXEC(node_out_control != nullptr, for (auto peer_in_control_anchor | |||||
| : node_out_control->GetPeerInControlAnchors()) { | |||||
| vector<ge::NodePtr>::iterator iter = find(nodes.begin(), nodes.end(), peer_in_control_anchor->GetOwnerNode()); | |||||
| GE_IF_BOOL_EXEC(iter == nodes.end(), output_control_anchors.emplace_back(node_out_control)); | |||||
| }); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -1518,19 +1495,18 @@ Status ParserGraphOptimizer::LinkInnerAnchor(unordered_map<string, ge::NodePtr> | |||||
| } | } | ||||
| InControlAnchorPtr node_in_control = node->GetInControlAnchor(); | InControlAnchorPtr node_in_control = node->GetInControlAnchor(); | ||||
| GE_IF_BOOL_EXEC( | |||||
| node_in_control != nullptr, for (auto peer_out_ctl_anchor | |||||
| : node_in_control->GetPeerOutControlAnchors()) { | |||||
| GE_IF_BOOL_EXEC(node_map.count(peer_out_ctl_anchor->GetOwnerNode()->GetName()) == 0, continue); | |||||
| NodePtr src_ctrl = node_map[peer_out_ctl_anchor->GetOwnerNode()->GetName()]; | |||||
| GE_IF_BOOL_EXEC( | |||||
| ge::GraphUtils::AddEdge(src_ctrl->GetOutControlAnchor(), dst->GetInControlAnchor()) != GRAPH_SUCCESS, | |||||
| GELOGE(FAILED, | |||||
| "LinkInnerAnchor Link control anchor failed, src node: " | |||||
| "%s, dst node: %s.", | |||||
| src_ctrl->GetName().c_str(), dst->GetName().c_str()); | |||||
| return FAILED); | |||||
| }); | |||||
| GE_IF_BOOL_EXEC(node_in_control != nullptr, for (auto peer_out_ctl_anchor | |||||
| : node_in_control->GetPeerOutControlAnchors()) { | |||||
| GE_IF_BOOL_EXEC(node_map.count(peer_out_ctl_anchor->GetOwnerNode()->GetName()) == 0, continue); | |||||
| NodePtr src_ctrl = node_map[peer_out_ctl_anchor->GetOwnerNode()->GetName()]; | |||||
| GE_IF_BOOL_EXEC( | |||||
| ge::GraphUtils::AddEdge(src_ctrl->GetOutControlAnchor(), dst->GetInControlAnchor()) != GRAPH_SUCCESS, | |||||
| GELOGE(FAILED, | |||||
| "LinkInnerAnchor Link control anchor failed, src node: " | |||||
| "%s, dst node: %s.", | |||||
| src_ctrl->GetName().c_str(), dst->GetName().c_str()); | |||||
| return FAILED); | |||||
| }); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -1881,24 +1857,24 @@ OpDescPtr ParserGraphOptimizer::CreateTranslateOp(enum ge::Format inFormat, enum | |||||
| static uint32_t transop_count = 0; | static uint32_t transop_count = 0; | ||||
| OpDescPtr op_def = nullptr; | OpDescPtr op_def = nullptr; | ||||
| std::stringstream sstmp; | std::stringstream sstmp; | ||||
| sstmp << "translate_" << ge::parser::TRANSDATA << "_" << transop_count++; | |||||
| GE_MAKE_SHARED(op_def = std::make_shared<OpDesc>(sstmp.str().c_str(), ge::parser::TRANSLATE), op_def = nullptr; | |||||
| sstmp << "translate_" << ge::TRANSDATA << "_" << transop_count++; | |||||
| GE_MAKE_SHARED(op_def = std::make_shared<OpDesc>(sstmp.str().c_str(), ge::TRANSLATE), op_def = nullptr; | |||||
| return op_def); | return op_def); | ||||
| GELOGI( | GELOGI( | ||||
| "create translate op:%s, input format:%s, input datatype:%s, output " | |||||
| "format:%s, output datatype:%s.", | |||||
| op_def->GetName().c_str(), ge::TypeUtils::FormatToSerialString(inFormat).c_str(), | |||||
| ge::TypeUtils::DataTypeToSerialString(inDatatype).c_str(), ge::TypeUtils::FormatToSerialString(outFormat).c_str(), | |||||
| ge::TypeUtils::DataTypeToSerialString(outDatatype).c_str()); | |||||
| GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ge::ATTR_NAME_INPUT_FORMAT, inFormat), return nullptr, | |||||
| "SetInt ATTR_NAME_INPUT_FORMAT failed."); | |||||
| GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ATTR_NAME_INPUT_DATATYPE, inDatatype), return nullptr, | |||||
| "SetInt ATTR_NAME_INPUT_DATATYPE failed."); | |||||
| GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ge::ATTR_NAME_OUTPUT_FORMAT, outFormat), return nullptr, | |||||
| "SetInt ATTR_NAME_INPUT_DATATYPE failed."); | |||||
| GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ATTR_NAME_OUTPUT_DATATYPE, outDatatype), return nullptr, | |||||
| "SetInt ATTR_NAME_INPUT_DATATYPE failed."); | |||||
| "create translate op:%s, input format:%s, input datatype:%s, output " | |||||
| "format:%s, output datatype:%s.", | |||||
| op_def->GetName().c_str(), ge::TypeUtils::FormatToSerialString(inFormat).c_str(), | |||||
| ge::TypeUtils::DataTypeToSerialString(inDatatype).c_str(), ge::TypeUtils::FormatToSerialString(outFormat).c_str(), | |||||
| ge::TypeUtils::DataTypeToSerialString(outDatatype).c_str()); | |||||
| GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ge::ATTR_NAME_INPUT_FORMAT, inFormat), | |||||
| return nullptr, "SetInt ATTR_NAME_INPUT_FORMAT failed."); | |||||
| GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ATTR_NAME_INPUT_DATATYPE, inDatatype), | |||||
| return nullptr, "SetInt ATTR_NAME_INPUT_DATATYPE failed."); | |||||
| GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ge::ATTR_NAME_OUTPUT_FORMAT, outFormat), | |||||
| return nullptr, "SetInt ATTR_NAME_INPUT_DATATYPE failed."); | |||||
| GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_def, ATTR_NAME_OUTPUT_DATATYPE, outDatatype), | |||||
| return nullptr, "SetInt ATTR_NAME_INPUT_DATATYPE failed."); | |||||
| if (inDatatype != ge::DT_FLOAT16) { | if (inDatatype != ge::DT_FLOAT16) { | ||||
| GE_CHK_BOOL_EXEC(SUCCESS == op_def->AddInputDesc(GeTensorDesc(GeShape(), inFormat)), return nullptr, | GE_CHK_BOOL_EXEC(SUCCESS == op_def->AddInputDesc(GeTensorDesc(GeShape(), inFormat)), return nullptr, | ||||
| "create translate op:add input desc fail."); | "create translate op:add input desc fail."); | ||||
| @@ -1920,17 +1896,17 @@ OpDescPtr ParserGraphOptimizer::CreatePermuteOp(enum ge::Format input_format, en | |||||
| static uint32_t transop_count = 0; | static uint32_t transop_count = 0; | ||||
| std::stringstream sstmp; | std::stringstream sstmp; | ||||
| sstmp << "transdata_" << ge::parser::PERMUTE << "_" << transop_count++; | |||||
| sstmp << "transdata_" << ge::PERMUTE << "_" << transop_count++; | |||||
| OpDescPtr op_desc = nullptr; | OpDescPtr op_desc = nullptr; | ||||
| GE_MAKE_SHARED(op_desc = std::make_shared<OpDesc>(sstmp.str().c_str(), ge::parser::PERMUTE), op_desc = nullptr; | |||||
| GE_MAKE_SHARED(op_desc = std::make_shared<OpDesc>(sstmp.str().c_str(), ge::PERMUTE), op_desc = nullptr; | |||||
| return op_desc); | return op_desc); | ||||
| GELOGI("create permute op:%s", op_desc->GetName().c_str()); | GELOGI("create permute op:%s", op_desc->GetName().c_str()); | ||||
| GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_INPUT_FORMAT, (int64_t)input_format), return nullptr, | |||||
| "SetInt ATTR_NAME_INPUT_FORMAT failed."); | |||||
| GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_OUTPUT_FORMAT, (int64_t)output_format), return nullptr, | |||||
| "SetInt ATTR_NAME_OUTPUT_FORMAT failed."); | |||||
| GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_INPUT_FORMAT, (int64_t)input_format), | |||||
| return nullptr, "SetInt ATTR_NAME_INPUT_FORMAT failed."); | |||||
| GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_OUTPUT_FORMAT, (int64_t)output_format), | |||||
| return nullptr, "SetInt ATTR_NAME_OUTPUT_FORMAT failed."); | |||||
| GE_IF_BOOL_EXEC(input_format == FORMAT_NCHW, (void)AttrUtils::SetInt(op_desc, "NCHW_to_NHWC", (int64_t)1)); | GE_IF_BOOL_EXEC(input_format == FORMAT_NCHW, (void)AttrUtils::SetInt(op_desc, "NCHW_to_NHWC", (int64_t)1)); | ||||
| GE_IF_BOOL_EXEC(input_format == FORMAT_NHWC, (void)AttrUtils::SetInt(op_desc, "NHWC_to_NCHW", (int64_t)1)); | GE_IF_BOOL_EXEC(input_format == FORMAT_NHWC, (void)AttrUtils::SetInt(op_desc, "NHWC_to_NCHW", (int64_t)1)); | ||||
| @@ -1947,11 +1923,10 @@ OpDescPtr ParserGraphOptimizer::CreateCastOp(enum ge::DataType input_data_type, | |||||
| enum ge::Format format) { | enum ge::Format format) { | ||||
| static uint32_t transop_count = 0; | static uint32_t transop_count = 0; | ||||
| std::stringstream sstmp; | std::stringstream sstmp; | ||||
| sstmp << "transdata_" << ge::parser::CAST << "_" << transop_count++; | |||||
| sstmp << "transdata_" << ge::CAST << "_" << transop_count++; | |||||
| OpDescPtr op_desc = nullptr; | OpDescPtr op_desc = nullptr; | ||||
| GE_MAKE_SHARED(op_desc = std::make_shared<OpDesc>(sstmp.str().c_str(), ge::parser::CAST), op_desc = nullptr; | |||||
| return op_desc); | |||||
| GE_MAKE_SHARED(op_desc = std::make_shared<OpDesc>(sstmp.str().c_str(), ge::CAST), op_desc = nullptr; return op_desc); | |||||
| GELOGI("create cast op:%s, input datatype:%s, out datatype:%s", op_desc->GetName().c_str(), | GELOGI("create cast op:%s, input datatype:%s, out datatype:%s", op_desc->GetName().c_str(), | ||||
| ge::TypeUtils::DataTypeToSerialString(input_data_type).c_str(), | ge::TypeUtils::DataTypeToSerialString(input_data_type).c_str(), | ||||
| ge::TypeUtils::DataTypeToSerialString(output_data_type).c_str()); | ge::TypeUtils::DataTypeToSerialString(output_data_type).c_str()); | ||||
| @@ -1975,10 +1950,10 @@ OpDescPtr ParserGraphOptimizer::CreateCastOp(enum ge::DataType input_data_type, | |||||
| OpDescPtr ParserGraphOptimizer::CreateTransDataOp(enum ge::Format input_format) { | OpDescPtr ParserGraphOptimizer::CreateTransDataOp(enum ge::Format input_format) { | ||||
| static uint32_t transop_count = 0; | static uint32_t transop_count = 0; | ||||
| std::stringstream sstmp; | std::stringstream sstmp; | ||||
| sstmp << "transdata_" << ge::parser::TRANSDATA << "_" << transop_count++; | |||||
| sstmp << "transdata_" << ge::TRANSDATA << "_" << transop_count++; | |||||
| OpDescPtr op_desc = nullptr; | OpDescPtr op_desc = nullptr; | ||||
| GE_MAKE_SHARED(op_desc = std::make_shared<OpDesc>(sstmp.str().c_str(), ge::parser::TRANSDATA), op_desc = nullptr; | |||||
| GE_MAKE_SHARED(op_desc = std::make_shared<OpDesc>(sstmp.str().c_str(), ge::TRANSDATA), op_desc = nullptr; | |||||
| return op_desc); | return op_desc); | ||||
| GELOGI("create transdata op:%s, input format:%s.", op_desc->GetName().c_str(), | GELOGI("create transdata op:%s, input format:%s.", op_desc->GetName().c_str(), | ||||
| @@ -1989,10 +1964,10 @@ OpDescPtr ParserGraphOptimizer::CreateTransDataOp(enum ge::Format input_format) | |||||
| output_format = FORMAT_NCHW; | output_format = FORMAT_NCHW; | ||||
| } | } | ||||
| GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_INPUT_FORMAT, (int64_t)input_format), return nullptr, | |||||
| "SetInt of ATTR_NAME_INPUT_FORMAT failed."); | |||||
| GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_OUTPUT_FORMAT, (int64_t)output_format), return nullptr, | |||||
| "SetInt of ATTR_NAME_OUTPUT_FORMAT failed."); | |||||
| GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_INPUT_FORMAT, (int64_t)input_format), | |||||
| return nullptr, "SetInt of ATTR_NAME_INPUT_FORMAT failed."); | |||||
| GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ge::ATTR_NAME_OUTPUT_FORMAT, (int64_t)output_format), | |||||
| return nullptr, "SetInt of ATTR_NAME_OUTPUT_FORMAT failed."); | |||||
| GE_CHK_BOOL_EXEC(SUCCESS == op_desc->AddInputDesc(GeTensorDesc(GeShape(), input_format)), return nullptr, | GE_CHK_BOOL_EXEC(SUCCESS == op_desc->AddInputDesc(GeTensorDesc(GeShape(), input_format)), return nullptr, | ||||
| "create transdata op:add input desc fail."); | "create transdata op:add input desc fail."); | ||||
| GE_CHK_BOOL_EXEC(SUCCESS == op_desc->AddOutputDesc(GeTensorDesc(GeShape(), output_format)), return nullptr, | GE_CHK_BOOL_EXEC(SUCCESS == op_desc->AddOutputDesc(GeTensorDesc(GeShape(), output_format)), return nullptr, | ||||
| @@ -2000,4 +1975,4 @@ OpDescPtr ParserGraphOptimizer::CreateTransDataOp(enum ge::Format input_format) | |||||
| return op_desc; | return op_desc; | ||||
| } | } | ||||
| } // namespace ge | |||||
| } // namespace domi | |||||
| @@ -20,7 +20,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <vector> | #include <vector> | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "common/types.h" | |||||
| #include "graph/anchor.h" | #include "graph/anchor.h" | ||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| #include "graph/node.h" | #include "graph/node.h" | ||||
| @@ -46,9 +46,8 @@ class ParserGraphOptimizer { | |||||
| domi::Status FusionFmkop(); | domi::Status FusionFmkop(); | ||||
| inline bool IsHCOMOp(const string &op_type) { | inline bool IsHCOMOp(const string &op_type) { | ||||
| return (op_type == ge::parser::HCOMALLREDUCE) || (op_type == ge::parser::HCOMALLGATHER) || | |||||
| (op_type == ge::parser::HCOMBROADCAST) || (op_type == ge::parser::HCOMSEND) || | |||||
| (op_type == ge::parser::HCOMRECEIVE) || (op_type == "HcomReduceScatter"); | |||||
| return (op_type == ge::HCOMALLREDUCE) || (op_type == ge::HCOMALLGATHER) || (op_type == ge::HCOMBROADCAST) || | |||||
| (op_type == ge::HCOMSEND) || (op_type == ge::HCOMRECEIVE) || (op_type == "HcomReduceScatter"); | |||||
| } | } | ||||
| void SetLocalFmkopFlag(bool isLocalFmkopFlag) { local_fmk_op_flag_ = isLocalFmkopFlag; } | void SetLocalFmkopFlag(bool isLocalFmkopFlag) { local_fmk_op_flag_ = isLocalFmkopFlag; } | ||||
| @@ -104,11 +103,11 @@ class ParserGraphOptimizer { | |||||
| domi::Status UpdateGraph(vector<ge::NodePtr> &nodes); | domi::Status UpdateGraph(vector<ge::NodePtr> &nodes); | ||||
| domi::Status InsertNode(ge::ComputeGraphPtr sub_graph, vector<ge::NodePtr> &nodes, | domi::Status InsertNode(ge::ComputeGraphPtr sub_graph, vector<ge::NodePtr> &nodes, | ||||
| vector<ge::InDataAnchorPtr> &input_anchors, vector<ge::OutDataAnchorPtr> &output_anchors, | |||||
| map<ge::OutDataAnchorPtr, vector<ge::InDataAnchorPtr>> &output_in_map, | |||||
| vector<ge::InControlAnchorPtr> &input_control_anchors, | |||||
| vector<ge::OutControlAnchorPtr> &output_control_anchors, | |||||
| unordered_map<string, ge::NodePtr> &node_map); | |||||
| vector<ge::InDataAnchorPtr> &input_anchors, vector<ge::OutDataAnchorPtr> &output_anchors, | |||||
| map<ge::OutDataAnchorPtr, vector<ge::InDataAnchorPtr>> &output_in_map, | |||||
| vector<ge::InControlAnchorPtr> &input_control_anchors, | |||||
| vector<ge::OutControlAnchorPtr> &output_control_anchors, | |||||
| unordered_map<string, ge::NodePtr> &node_map); | |||||
| domi::Status LinkInnerAnchor(unordered_map<string, ge::NodePtr> &node_map); | domi::Status LinkInnerAnchor(unordered_map<string, ge::NodePtr> &node_map); | ||||
| @@ -124,5 +123,5 @@ class ParserGraphOptimizer { | |||||
| domi::Status MakeTfProtoDef(); | domi::Status MakeTfProtoDef(); | ||||
| }; | }; | ||||
| } // namespace ge | |||||
| } // namespace domi | |||||
| #endif // GE_GRAPH_OPTIMIZE_GRAPH_OPTIMIZER_H_ | #endif // GE_GRAPH_OPTIMIZE_GRAPH_OPTIMIZER_H_ | ||||
| @@ -19,7 +19,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include "common/debug/log.h" | #include "common/debug/log.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "common/types.h" | |||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "graph_optimizer.h" | #include "graph_optimizer.h" | ||||
| #include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
| @@ -45,9 +45,6 @@ message AippOpParams { | |||||
| // 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。 | // 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。 | ||||
| uint32 related_input_rank = 2; | uint32 related_input_rank = 2; | ||||
| // related_input_name is optional and the top name of data node which inserts aipp | |||||
| string related_input_name = 6; | |||||
| // input_edge_idx参数为可选,类型为整型,配置范围为>=0。 | // input_edge_idx参数为可选,类型为整型,配置范围为>=0。 | ||||
| // 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。 | // 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。 | ||||
| // 配置值 <= Data算子输出边的个数。 | // 配置值 <= Data算子输出边的个数。 | ||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #include "parser/tensorflow/scope/scope_pass_manager.h" | #include "parser/tensorflow/scope/scope_pass_manager.h" | ||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "common/util/error_manager/error_manager.h" | #include "common/util/error_manager/error_manager.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| @@ -25,7 +25,7 @@ | |||||
| namespace ge { | namespace ge { | ||||
| shared_ptr<ScopeGraph> ScopePassManager::BuildScopeGraph(domi::tensorflow::GraphDef *graph_def) { | shared_ptr<ScopeGraph> ScopePassManager::BuildScopeGraph(domi::tensorflow::GraphDef *graph_def) { | ||||
| GE_CHK_BOOL_EXEC(graph_def != nullptr, return nullptr, "graph_def is nullptr"); | GE_CHK_BOOL_EXEC(graph_def != nullptr, return nullptr, "graph_def is nullptr"); | ||||
| scope_graph_ = ge::parser::MakeShared<ScopeGraph>(); | |||||
| scope_graph_ = ge::MakeShared<ScopeGraph>(); | |||||
| if (scope_graph_ == nullptr) { | if (scope_graph_ == nullptr) { | ||||
| GELOGE(FAILED, "Scope graph make shared failed."); | GELOGE(FAILED, "Scope graph make shared failed."); | ||||
| return nullptr; | return nullptr; | ||||
| @@ -17,7 +17,6 @@ | |||||
| #include "common/debug/log.h" | #include "common/debug/log.h" | ||||
| #include "parser/common/op_def/arg_op.h" | #include "parser/common/op_def/arg_op.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/omg/parser/parser_inner_ctx.h" | |||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| #include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| @@ -45,7 +44,7 @@ Status ParseParams(const Message *op_src, ArgOpOperator *op) { | |||||
| "trans output_attr_value failed, op: %s", node->name().c_str()); | "trans output_attr_value failed, op: %s", node->name().c_str()); | ||||
| domi::tensorflow::AttrValue_ListValue attr_list = output_attr_value.list(); | domi::tensorflow::AttrValue_ListValue attr_list = output_attr_value.list(); | ||||
| GetParserContext().format = | |||||
| domi::GetContext().format = | |||||
| static_cast<domi::tagDomiTensorFormat>(attr_list.func(0).attr().at(kSerializeFormat).i()); | static_cast<domi::tagDomiTensorFormat>(attr_list.func(0).attr().at(kSerializeFormat).i()); | ||||
| } else { | } else { | ||||
| /// _Arg constructed from inference function do not has input_tensor_dec | /// _Arg constructed from inference function do not has input_tensor_dec | ||||
| @@ -65,5 +64,5 @@ Status ParseParams(const Message *op_src, ArgOpOperator *op) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| DOMI_REGISTER_TENSORFLOW_PARSER(ge::parser::ARG, ArgOpOperator).SetParseParamsFn(ParseParams); | |||||
| DOMI_REGISTER_TENSORFLOW_PARSER(ge::ARG, ArgOpOperator).SetParseParamsFn(ParseParams); | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -16,7 +16,6 @@ | |||||
| #include "tensorflow_auto_mapping_parser_adapter.h" | #include "tensorflow_auto_mapping_parser_adapter.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| @@ -25,9 +24,6 @@ | |||||
| using domi::TENSORFLOW; | using domi::TENSORFLOW; | ||||
| using namespace ge::parser; | |||||
| using ge::parser::PLACEHOLDERWITHDEFAULT; | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| @@ -19,7 +19,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include "common/debug/log.h" | #include "common/debug/log.h" | ||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "common/op/ge_op_utils.h" | #include "common/op/ge_op_utils.h" | ||||
| #include "parser/common/op_def/constant_op.h" | #include "parser/common/op_def/constant_op.h" | ||||
| #include "parser/common/op_def/ir_pb_converter.h" | #include "parser/common/op_def/ir_pb_converter.h" | ||||
| @@ -27,12 +27,10 @@ | |||||
| #include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
| #include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "register/tensor_assign.h" | #include "register/tensor_assign.h" | ||||
| using domi::tensorflow::NodeDef; | using domi::tensorflow::NodeDef; | ||||
| using domi::TENSORFLOW; | using domi::TENSORFLOW; | ||||
| using ge::parser::CONSTANTOP; | |||||
| namespace ge { | namespace ge { | ||||
| Status TensorFlowConstantParser::ParseDType(const domi::tensorflow::NodeDef *node, ConstantOperator *op) { | Status TensorFlowConstantParser::ParseDType(const domi::tensorflow::NodeDef *node, ConstantOperator *op) { | ||||
| @@ -68,7 +66,7 @@ Status TensorFlowConstantParser::ParseValue(const domi::tensorflow::NodeDef *nod | |||||
| const domi::tensorflow::TensorProto &tensor = attr_value.tensor(); | const domi::tensorflow::TensorProto &tensor = attr_value.tensor(); | ||||
| GeTensorPtr weight = ge::parser::MakeShared<ge::GeTensor>(); | |||||
| GeTensorPtr weight = ge::MakeShared<ge::GeTensor>(); | |||||
| GE_CHECK_NOTNULL(weight); | GE_CHECK_NOTNULL(weight); | ||||
| int64_t dataType = 0; | int64_t dataType = 0; | ||||
| GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetInt(opDesc, TENSORFLOW_ATTR_DTYPE, dataType), INTERNAL_ERROR, | GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetInt(opDesc, TENSORFLOW_ATTR_DTYPE, dataType), INTERNAL_ERROR, | ||||
| @@ -19,14 +19,11 @@ | |||||
| #include "common/debug/log.h" | #include "common/debug/log.h" | ||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/omg/parser/parser_inner_ctx.h" | |||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| using domi::tensorflow::AttrValue; | using domi::tensorflow::AttrValue; | ||||
| using domi::tensorflow::NodeDef; | using domi::tensorflow::NodeDef; | ||||
| using domi::TENSORFLOW; | using domi::TENSORFLOW; | ||||
| using ge::parser::DATA; | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| @@ -100,7 +97,7 @@ Status TensorFlowDataParser::ParseInputFromModel(const Message *op_src, ge::OpDe | |||||
| Status TensorFlowDataParser::ParseInputFromUser(const Message *op_src, const ge::OpDescPtr &op_def) { | Status TensorFlowDataParser::ParseInputFromUser(const Message *op_src, const ge::OpDescPtr &op_def) { | ||||
| GE_CHECK_NOTNULL(op_def); | GE_CHECK_NOTNULL(op_def); | ||||
| (void)op_src; | (void)op_src; | ||||
| const ge::ParserContext &ctx = GetParserContext(); | |||||
| const ge::OmgContext &ctx = domi::GetContext(); | |||||
| std::unordered_map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | std::unordered_map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; | ||||
| // User not designate the input_shape | // User not designate the input_shape | ||||
| std::string name = op_def->GetName(); | std::string name = op_def->GetName(); | ||||
| @@ -134,7 +131,7 @@ Status TensorFlowDataParser::ParseInputFromUser(const Message *op_src, const ge: | |||||
| } | } | ||||
| Status TensorFlowDataParser::CheckInputShape(const std::string &name) { | Status TensorFlowDataParser::CheckInputShape(const std::string &name) { | ||||
| const ge::ParserContext &ctx = GetParserContext(); | |||||
| const ge::OmgContext &ctx = domi::GetContext(); | |||||
| if (!ctx.is_dynamic_input) { | if (!ctx.is_dynamic_input) { | ||||
| for (uint32_t i = 0; i < user_input_dims_v.size(); i++) { | for (uint32_t i = 0; i < user_input_dims_v.size(); i++) { | ||||
| // if input_shape has some placeholders, user should designate them. | // if input_shape has some placeholders, user should designate them. | ||||
| @@ -19,11 +19,8 @@ | |||||
| #include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| using domi::TENSORFLOW; | using domi::TENSORFLOW; | ||||
| using ge::parser::ENTER; | |||||
| using ge::parser::REFENTER; | |||||
| namespace ge { | namespace ge { | ||||
| Status TensorFlowEnterParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_desc) { | Status TensorFlowEnterParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_desc) { | ||||
| @@ -20,11 +20,6 @@ | |||||
| #include "parser/common/op_def/fill_op.h" | #include "parser/common/op_def/fill_op.h" | ||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "parser/tensorflow/tensorflow_parser_register.h" | #include "parser/tensorflow/tensorflow_parser_register.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| using ge::parser::ALPHA_DEFAULT_VALUE; | |||||
| using ge::parser::BETA_DEFAULT_VALUE; | |||||
| using ge::parser::FILL; | |||||
| namespace ge { | namespace ge { | ||||
| /* | /* | ||||
| @@ -58,8 +53,8 @@ domi::Status ParseParams(const NodeDef *node, FillOperator *op) { | |||||
| op->DataType(type); | op->DataType(type); | ||||
| op->Alpha(ge::parser::ALPHA_DEFAULT_VALUE); | |||||
| op->Beta(ge::parser::BETA_DEFAULT_VALUE); | |||||
| op->Alpha(ge::ALPHA_DEFAULT_VALUE); | |||||
| op->Beta(ge::BETA_DEFAULT_VALUE); | |||||
| return domi::SUCCESS; | return domi::SUCCESS; | ||||
| } | } | ||||
| @@ -18,15 +18,14 @@ | |||||
| #include "parser/common/op_def/frameworkop_op.h" | #include "parser/common/op_def/frameworkop_op.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "parser/tensorflow/tensorflow_op_parser.h" | #include "parser/tensorflow/tensorflow_op_parser.h" | ||||
| #include "parser/tensorflow/tensorflow_parser_register.h" | #include "parser/tensorflow/tensorflow_parser_register.h" | ||||
| #include "proto/tensorflow/tensor_shape.pb.h" | #include "proto/tensorflow/tensor_shape.pb.h" | ||||
| using domi::tensorflow::TensorShapeProto; | using domi::tensorflow::TensorShapeProto; | ||||
| using domi::tensorflow::AttrValue; | using domi::tensorflow::AttrValue; | ||||
| using ge::FRAMEWORKOP; | |||||
| using domi::TENSORFLOW; | using domi::TENSORFLOW; | ||||
| using ge::parser::FRAMEWORKOP; | |||||
| namespace ge { | namespace ge { | ||||
| Status ParseParams(const Message *op_src, FrameworkOpOperator *op) { | Status ParseParams(const Message *op_src, FrameworkOpOperator *op) { | ||||
| @@ -17,11 +17,11 @@ | |||||
| #include "parser/tensorflow/tensorflow_fusion_op_parser.h" | #include "parser/tensorflow/tensorflow_fusion_op_parser.h" | ||||
| #include <memory> | #include <memory> | ||||
| #include "common/debug/log.h" | #include "common/debug/log.h" | ||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include "common/fp16_t.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "omg/omg.h" | #include "omg/omg.h" | ||||
| #include "parser/common/parser_fp16_t.h" | |||||
| #include "parser/tensorflow/tensorflow_op_parser.h" | #include "parser/tensorflow/tensorflow_op_parser.h" | ||||
| #include "register/tensor_assign.h" | #include "register/tensor_assign.h" | ||||
| @@ -115,7 +115,7 @@ Status TensorFlowFusionOpParser::ParseHalfFromConst(const NodeDef *node_def, flo | |||||
| auto val_vec = tensor.half_val(); | auto val_vec = tensor.half_val(); | ||||
| int32_t val_size = val_vec.size(); | int32_t val_size = val_vec.size(); | ||||
| if (index < val_size) { | if (index < val_size) { | ||||
| ge::parser::fp16_t fp16_value = static_cast<parser::fp16_t>(val_vec.Get(index)); | |||||
| fp16_t fp16_value = static_cast<fp16_t>(val_vec.Get(index)); | |||||
| param = fp16_value.ToFloat(); | param = fp16_value.ToFloat(); | ||||
| } else { | } else { | ||||
| GELOGE(domi::PARAM_INVALID, "Const data size is smaller than index:%d, not supported.", index); | GELOGE(domi::PARAM_INVALID, "Const data size is smaller than index:%d, not supported.", index); | ||||
| @@ -132,7 +132,7 @@ Status TensorFlowFusionOpParser::ParseWeightFromConst(const NodeDef *node_def, g | |||||
| GE_CHECK_NOTNULL(node_def); | GE_CHECK_NOTNULL(node_def); | ||||
| TensorProto tensor; | TensorProto tensor; | ||||
| GE_CHK_STATUS_RET(GetTensorFromNode(node_def, tensor), "get tensor failed."); | GE_CHK_STATUS_RET(GetTensorFromNode(node_def, tensor), "get tensor failed."); | ||||
| weight = ge::parser::MakeShared<ge::GeTensor>(); | |||||
| weight = ge::MakeShared<ge::GeTensor>(); | |||||
| GE_CHECK_NOTNULL(weight); | GE_CHECK_NOTNULL(weight); | ||||
| domi::tensorflow::DataType data_type = tensor.dtype(); | domi::tensorflow::DataType data_type = tensor.dtype(); | ||||
| GE_CHK_STATUS_RET( | GE_CHK_STATUS_RET( | ||||
| @@ -20,7 +20,6 @@ | |||||
| #include "common/op/ge_op_utils.h" | #include "common/op/ge_op_utils.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "parser/tensorflow/tensorflow_parser.h" | #include "parser/tensorflow/tensorflow_parser.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include <iostream> | #include <iostream> | ||||
| #include <cstdlib> | #include <cstdlib> | ||||
| @@ -114,21 +113,21 @@ static map<string, string> tensorflow_fusionop_map = { | |||||
| // <Types of fusion operators, Number of children operators> | // <Types of fusion operators, Number of children operators> | ||||
| static map<string, vector<int>> tensorflow_fusionop_children_nums_map = { | static map<string, vector<int>> tensorflow_fusionop_children_nums_map = { | ||||
| {ge::parser::CLIPBOXES, {8}}, | |||||
| {ge::parser::FASTRCNNPREDICTIONS, {118, 119, 120, 123, 125}}, | |||||
| {ge::parser::RPNPROPOSALS, {75, 85, 97}}, | |||||
| {ge::parser::DECODEBBOX, {24, 28}}, | |||||
| {ge::parser::ROIALIGN, {82, 83, 84}}, | |||||
| {ge::parser::FUSIONBATCHNORM, {8}}, | |||||
| {ge::parser::GETSPAN, {81, 71, 91}}, // The pbtxt only has 62 nodes when test GetSpan sub net. However the | |||||
| {ge::parser::HUBERLOSSGRAD, {8, 9, 10, 20, 21}}, | |||||
| {CLIPBOXES, {8}}, | |||||
| {FASTRCNNPREDICTIONS, {118, 119, 120, 123, 125}}, | |||||
| {RPNPROPOSALS, {75, 85, 97}}, | |||||
| {DECODEBBOX, {24, 28}}, | |||||
| {ROIALIGN, {82, 83, 84}}, | |||||
| {FUSIONBATCHNORM, {8}}, | |||||
| {GETSPAN, {81, 71, 91}}, // The pbtxt only has 62 nodes when test GetSpan sub net. However the | |||||
| {HUBERLOSSGRAD, {8, 9, 10, 20, 21}}, | |||||
| }; | }; | ||||
| // <Types of fusion operators, Name of children operators(Remove the prefixes and/)> | // <Types of fusion operators, Name of children operators(Remove the prefixes and/)> | ||||
| static map<string, vector<string>> tensorflow_fusionop_children_names_map = { | static map<string, vector<string>> tensorflow_fusionop_children_names_map = { | ||||
| {ge::parser::FUSIONBATCHNORM, {"add/y", "add", "Rsqrt", "mul", "mul_1", "mul_2", "sub", "add_1"}}, | |||||
| {ge::parser::GETSPAN, {}}, | |||||
| {ge::parser::HUBERLOSSGRAD, {}}, | |||||
| {FUSIONBATCHNORM, {"add/y", "add", "Rsqrt", "mul", "mul_1", "mul_2", "sub", "add_1"}}, | |||||
| {GETSPAN, {}}, | |||||
| {HUBERLOSSGRAD, {}}, | |||||
| }; | }; | ||||
| // ----------------------------Index table of input and output of fusion operator-------------- | // ----------------------------Index table of input and output of fusion operator-------------- | ||||
| @@ -138,23 +137,23 @@ static map<string, vector<string>> tensorflow_fusionop_children_names_map = { | |||||
| // Generally, the old index is 0. If the new index value is kFusionDisableIndex, the edge can be ignored. | // Generally, the old index is 0. If the new index value is kFusionDisableIndex, the edge can be ignored. | ||||
| // If it is control edge input, the index is graph::kControlSlot(-1). | // If it is control edge input, the index is graph::kControlSlot(-1). | ||||
| static map<string, vector<std::pair<string, vector<int32_t>>>> tensorflow_fusionop_inputs_map = { | static map<string, vector<std::pair<string, vector<int32_t>>>> tensorflow_fusionop_inputs_map = { | ||||
| {ge::parser::FUSIONBATCHNORM, | |||||
| {FUSIONBATCHNORM, | |||||
| {{"mul_1", {0, kFusionDisableIndex}}, | {{"mul_1", {0, kFusionDisableIndex}}, | ||||
| {"mul", {1, 1}}, | {"mul", {1, 1}}, | ||||
| {"sub", {2, kFusionDisableIndex}}, | {"sub", {2, kFusionDisableIndex}}, | ||||
| {"mul_2", {3, kFusionDisableIndex}}, | {"mul_2", {3, kFusionDisableIndex}}, | ||||
| {"add", {4, kFusionDisableIndex}}}}, | {"add", {4, kFusionDisableIndex}}}}, | ||||
| {ge::parser::GETSPAN, {{"transpose", {0}}, {"TensorArray", {1}}, {"transpose_1", {2}}}}, | |||||
| {ge::parser::HUBERLOSSGRAD, {{"Sub_1_grad/Neg", {1}}, {"Abs_grad/Sign", {0}}}}, | |||||
| {GETSPAN, {{"transpose", {0}}, {"TensorArray", {1}}, {"transpose_1", {2}}}}, | |||||
| {HUBERLOSSGRAD, {{"Sub_1_grad/Neg", {1}}, {"Abs_grad/Sign", {0}}}}, | |||||
| }; | }; | ||||
| static map<string, vector<std::pair<string, vector<int32_t>>>> tensorflow_fusionop_outputs_map = { | static map<string, vector<std::pair<string, vector<int32_t>>>> tensorflow_fusionop_outputs_map = { | ||||
| {ge::parser::FUSIONBATCHNORM, {{"add_1", {0}}}}, | |||||
| {ge::parser::GETSPAN, {{"while/Exit_1", {0}}, {"while/Exit_2", {1}}}}, | |||||
| {ge::parser::HUBERLOSSGRAD, {{"Abs_grad/mul", {0}}}}, | |||||
| {FUSIONBATCHNORM, {{"add_1", {0}}}}, | |||||
| {GETSPAN, {{"while/Exit_1", {0}}, {"while/Exit_2", {1}}}}, | |||||
| {HUBERLOSSGRAD, {{"Abs_grad/mul", {0}}}}, | |||||
| }; | }; | ||||
| map<string, vector<std::pair<string, uint32_t>>> tensorflow_fusionop_input_const_weight_index_map = { | map<string, vector<std::pair<string, uint32_t>>> tensorflow_fusionop_input_const_weight_index_map = { | ||||
| {ge::parser::FUSIONBATCHNORM, {{"mul", 0}, {"sub", 1}, {"mul_2", 2}, {"add", 3}}}, | |||||
| {FUSIONBATCHNORM, {{"mul", 0}, {"sub", 1}, {"mul_2", 2}, {"add", 3}}}, | |||||
| }; | }; | ||||
| // Can a string be converted to an integer | // Can a string be converted to an integer | ||||
| @@ -22,7 +22,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "common/debug/log.h" | #include "common/debug/log.h" | ||||
| #include "common/string_util.h" | #include "common/string_util.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "common/types.h" | |||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "omg/omg_inner_types.h" | #include "omg/omg_inner_types.h" | ||||
| #include "proto/tensorflow/graph.pb.h" | #include "proto/tensorflow/graph.pb.h" | ||||
| @@ -17,15 +17,11 @@ | |||||
| #include "common/op/ge_op_utils.h" | #include "common/op/ge_op_utils.h" | ||||
| #include "common/op_def/ir_pb_converter.h" | #include "common/op_def/ir_pb_converter.h" | ||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "parser/tensorflow/tensorflow_identity_parser.h" | #include "parser/tensorflow/tensorflow_identity_parser.h" | ||||
| using domi::TENSORFLOW; | using domi::TENSORFLOW; | ||||
| using ge::parser::IDENTITY; | |||||
| using ge::parser::READVARIABLEOP; | |||||
| namespace ge { | namespace ge { | ||||
| REGISTER_OP_PARSER_CREATOR(TENSORFLOW, IDENTITY, TensorFlowIdentityParser); | REGISTER_OP_PARSER_CREATOR(TENSORFLOW, IDENTITY, TensorFlowIdentityParser); | ||||
| REGISTER_OP_PARSER_CREATOR(TENSORFLOW, READVARIABLEOP, TensorFlowIdentityParser); | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -20,10 +20,8 @@ | |||||
| #include "framework/common/util.h" | #include "framework/common/util.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| using domi::TENSORFLOW; | using domi::TENSORFLOW; | ||||
| using ge::parser::MERGE; | |||||
| namespace ge { | namespace ge { | ||||
| Status TensorFlowMergeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_desc) { | Status TensorFlowMergeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_desc) { | ||||
| @@ -22,7 +22,6 @@ | |||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| using domi::TENSORFLOW; | using domi::TENSORFLOW; | ||||
| using namespace ge::parser; | |||||
| namespace ge { | namespace ge { | ||||
| Status TensorFlowNoOpParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { | Status TensorFlowNoOpParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { | ||||
| @@ -25,7 +25,7 @@ | |||||
| #include "framework/omg/parser/op_parser.h" | #include "framework/omg/parser/op_parser.h" | ||||
| #include "parser/common/op_def/ir_pb_converter.h" | #include "parser/common/op_def/ir_pb_converter.h" | ||||
| #include "parser/common/op_def/operator.h" | #include "parser/common/op_def/operator.h" | ||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "parser/tensorflow/tensorflow_op_parser.h" | #include "parser/tensorflow/tensorflow_op_parser.h" | ||||
| #include "proto/tensorflow/node_def.pb.h" | #include "proto/tensorflow/node_def.pb.h" | ||||
| @@ -72,7 +72,7 @@ class TensorflowParserBuilder : public TensorflowWeightParserBuilder { | |||||
| } | } | ||||
| bool Finalize() override { | bool Finalize() override { | ||||
| auto op_parser_adapter = ge::parser::MakeShared<TensorflowOpParserAdapter<Param>>(*this); | |||||
| auto op_parser_adapter = ge::MakeShared<TensorflowOpParserAdapter<Param>>(*this); | |||||
| if (op_parser_adapter == nullptr) { | if (op_parser_adapter == nullptr) { | ||||
| GELOGE(FAILED, "Op parser adapter is null."); | GELOGE(FAILED, "Op parser adapter is null."); | ||||
| } | } | ||||
| @@ -102,7 +102,7 @@ class TensorflowOpParserAdapter : public TensorFlowOpParser { | |||||
| Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override { | Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override { | ||||
| const domi::tensorflow::NodeDef *node = static_cast<const domi::tensorflow::NodeDef *>(op_src); | const domi::tensorflow::NodeDef *node = static_cast<const domi::tensorflow::NodeDef *>(op_src); | ||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| std::shared_ptr<Param> param = ge::parser::MakeShared<Param>(); | |||||
| std::shared_ptr<Param> param = ge::MakeShared<Param>(); | |||||
| if (param == nullptr) { | if (param == nullptr) { | ||||
| GELOGE(domi::FAILED, "Param is null"); | GELOGE(domi::FAILED, "Param is null"); | ||||
| return domi::FAILED; | return domi::FAILED; | ||||
| @@ -26,7 +26,6 @@ using domi::tensorflow::DT_FLOAT; | |||||
| using domi::tensorflow::AttrValue; | using domi::tensorflow::AttrValue; | ||||
| using domi::tensorflow::NodeDef; | using domi::tensorflow::NodeDef; | ||||
| using domi::TENSORFLOW; | using domi::TENSORFLOW; | ||||
| using namespace ge::parser; | |||||
| namespace ge { | namespace ge { | ||||
| // AUTO GEN PLEASE DO NOT MODIFY IT | // AUTO GEN PLEASE DO NOT MODIFY IT | ||||
| @@ -22,10 +22,9 @@ | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "parser/tensorflow/tensorflow_util.h" | #include "parser/tensorflow/tensorflow_util.h" | ||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include "common/math/math_util.h" | |||||
| using domi::TENSORFLOW; | using domi::TENSORFLOW; | ||||
| using namespace ge::parser; | |||||
| namespace ge { | namespace ge { | ||||
| Status TensorFlowReshapeParser::ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc) { | Status TensorFlowReshapeParser::ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc) { | ||||
| @@ -48,7 +47,7 @@ Status TensorFlowReshapeParser::ParseDesc(const domi::tensorflow::AttrValue &att | |||||
| GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;); | GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;); | ||||
| real_size *= tmp_dim; | real_size *= tmp_dim; | ||||
| } | } | ||||
| PARSER_INT64_MULCHECK(real_size, size_type); | |||||
| FMK_INT64_MULCHECK(real_size, size_type); | |||||
| ge::TensorUtils::SetSize(ge_desc, real_size * size_type); | ge::TensorUtils::SetSize(ge_desc, real_size * size_type); | ||||
| ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum()); | ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum()); | ||||
| GELOGI("after translate tf_desc, datatype: %s, format: %s, real size: %u, size_type: %u", | GELOGI("after translate tf_desc, datatype: %s, format: %s, real size: %u, size_type: %u", | ||||
| @@ -68,7 +67,7 @@ Status TensorFlowReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr | |||||
| domi::tensorflow::AttrValue output_attr_value; | domi::tensorflow::AttrValue output_attr_value; | ||||
| GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
| GetParserContext().train_flag == true, | |||||
| domi::GetContext().train_flag == true, | |||||
| ge::GeTensorDesc input_desc; | ge::GeTensorDesc input_desc; | ||||
| ge::GeTensorDesc output_desc; | ge::GeTensorDesc output_desc; | ||||
| @@ -26,7 +26,6 @@ using domi::tensorflow::AttrValue; | |||||
| using domi::tensorflow::DataType; | using domi::tensorflow::DataType; | ||||
| using domi::tensorflow::DT_FLOAT; | using domi::tensorflow::DT_FLOAT; | ||||
| using domi::tensorflow::DT_INT32; | using domi::tensorflow::DT_INT32; | ||||
| using namespace ge::parser; | |||||
| namespace { | namespace { | ||||
| const std::string kShapeAttrDtype = "out_type"; | const std::string kShapeAttrDtype = "out_type"; | ||||
| @@ -22,16 +22,14 @@ | |||||
| #include "framework/common/op/attr_value_util.h" | #include "framework/common/op/attr_value_util.h" | ||||
| #include "framework/common/op/op_parser_util.h" | #include "framework/common/op/op_parser_util.h" | ||||
| #include "framework/common/util.h" | #include "framework/common/util.h" | ||||
| #include "framework/omg/parser/parser_inner_ctx.h" | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| #include "parser/common/acl_graph_parser_util.h" | |||||
| #include "common/math/math_util.h" | |||||
| using domi::tensorflow::AttrValue; | using domi::tensorflow::AttrValue; | ||||
| using std::vector; | using std::vector; | ||||
| using std::shared_ptr; | using std::shared_ptr; | ||||
| using domi::TENSORFLOW; | using domi::TENSORFLOW; | ||||
| using namespace ge::parser; | |||||
| namespace ge { | namespace ge { | ||||
| Status TensorFlowSqueezeParser::ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc) { | Status TensorFlowSqueezeParser::ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc) { | ||||
| @@ -52,10 +50,10 @@ Status TensorFlowSqueezeParser::ParseDesc(const domi::tensorflow::AttrValue &att | |||||
| for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) { | for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) { | ||||
| tmp_dim = ge_desc.GetShape().GetDim(j); | tmp_dim = ge_desc.GetShape().GetDim(j); | ||||
| GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;); | GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;); | ||||
| PARSER_INT64_MULCHECK(real_size, tmp_dim); | |||||
| FMK_INT64_MULCHECK(real_size, tmp_dim); | |||||
| real_size *= tmp_dim; | real_size *= tmp_dim; | ||||
| } | } | ||||
| PARSER_INT64_MULCHECK(real_size, size_type); | |||||
| FMK_INT64_MULCHECK(real_size, size_type); | |||||
| ge::TensorUtils::SetSize(ge_desc, real_size * size_type); | ge::TensorUtils::SetSize(ge_desc, real_size * size_type); | ||||
| ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum()); | ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum()); | ||||
| GELOGD("after translate tf_desc, datatype: %s, format: %s, real size: %u, size_type: %u", | GELOGD("after translate tf_desc, datatype: %s, format: %s, real size: %u, size_type: %u", | ||||
| @@ -112,7 +110,7 @@ Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr | |||||
| domi::tensorflow::AttrValue output_attr_value; | domi::tensorflow::AttrValue output_attr_value; | ||||
| GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
| GetParserContext().train_flag == true, ge::GeTensorDesc input_desc; ge::GeTensorDesc output_desc; | |||||
| domi::GetContext().train_flag == true, ge::GeTensorDesc input_desc; ge::GeTensorDesc output_desc; | |||||
| if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value)) { | if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value)) { | ||||
| GE_CHK_BOOL_RET_STATUS(ParseDesc(input_attr_value, input_desc) == SUCCESS, FAILED, "parse input desc failed"); | GE_CHK_BOOL_RET_STATUS(ParseDesc(input_attr_value, input_desc) == SUCCESS, FAILED, "parse input desc failed"); | ||||
| @@ -15,25 +15,25 @@ | |||||
| */ | */ | ||||
| #include "parser/tensorflow/tensorflow_util.h" | #include "parser/tensorflow/tensorflow_util.h" | ||||
| #include <cstdint> | |||||
| #include <cstdlib> | #include <cstdlib> | ||||
| #include <cstdint> | |||||
| #include <iostream> | #include <iostream> | ||||
| #include <memory> | #include <memory> | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
| #include "framework/common/op/ge_op_utils.h" | #include "framework/common/op/ge_op_utils.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "parser/tensorflow/tensorflow_op_parser.h" | #include "parser/tensorflow/tensorflow_op_parser.h" | ||||
| #include "common/math/math_util.h" | |||||
| using domi::tensorflow::DT_INVALID; | using domi::tensorflow::DT_INVALID; | ||||
| namespace ge { | namespace ge { | ||||
| using AttrValueMap = ::google::protobuf::Map<string, domi::tensorflow::AttrValue>; | using AttrValueMap = ::google::protobuf::Map<string, domi::tensorflow::AttrValue>; | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::FindAttrValue( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::FindAttrValue( | ||||
| const domi::tensorflow::NodeDef *node_def, const string &attr_name, domi::tensorflow::AttrValue &attr_value) { | |||||
| const domi::tensorflow::NodeDef *node_def, const string &attr_name, domi::tensorflow::AttrValue &attr_value) { | |||||
| GE_CHECK_NOTNULL(node_def); | GE_CHECK_NOTNULL(node_def); | ||||
| const google::protobuf::Map<std::string, domi::tensorflow::AttrValue> &attr = node_def->attr(); | const google::protobuf::Map<std::string, domi::tensorflow::AttrValue> &attr = node_def->attr(); | ||||
| const google::protobuf::Map<std::string, domi::tensorflow::AttrValue>::const_iterator it = attr.find(attr_name); | const google::protobuf::Map<std::string, domi::tensorflow::AttrValue>::const_iterator it = attr.find(attr_name); | ||||
| @@ -46,7 +46,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::FindAttrVa | |||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::CheckAttrHasType( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::CheckAttrHasType( | ||||
| const domi::tensorflow::AttrValue &attr_value, const string &type) { | |||||
| const domi::tensorflow::AttrValue &attr_value, const string &type) { | |||||
| uint32_t num_set = 0; | uint32_t num_set = 0; | ||||
| #define VALIDATE_FIELD(name, type_string, oneof_case) \ | #define VALIDATE_FIELD(name, type_string, oneof_case) \ | ||||
| do { \ | do { \ | ||||
| @@ -118,7 +118,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Ch | |||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::ParseDataType( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::ParseDataType( | ||||
| const NodeDef *node_src, const string &attr_src, domi::tensorflow::DataType &data_type) { | |||||
| const NodeDef *node_src, const string &attr_src, domi::tensorflow::DataType &data_type) { | |||||
| GE_CHECK_NOTNULL(node_src); | GE_CHECK_NOTNULL(node_src); | ||||
| string node_name = node_src->name(); | string node_name = node_src->name(); | ||||
| @@ -138,7 +138,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Pa | |||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::ParseFromAttrValueList( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::ParseFromAttrValueList( | ||||
| ge::GeTensorDesc &ge_desc, const domi::tensorflow::AttrValue_ListValue &a_list, int32_t i, int32_t &tf_datatype) { | |||||
| ge::GeTensorDesc &ge_desc, const domi::tensorflow::AttrValue_ListValue &a_list, int32_t i, int32_t &tf_datatype) { | |||||
| const std::string SERIALIZE_FORMAT = "serialize_format"; | const std::string SERIALIZE_FORMAT = "serialize_format"; | ||||
| const std::string SERIALIZE_DATATYPE = "serialize_datatype"; | const std::string SERIALIZE_DATATYPE = "serialize_datatype"; | ||||
| const std::string SERIALIZE_SHAPE = "serialize_shape"; | const std::string SERIALIZE_SHAPE = "serialize_shape"; | ||||
| @@ -162,7 +162,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::ParseFromA | |||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::TransTensorDescriptor( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::TransTensorDescriptor( | ||||
| const domi::tensorflow::AttrValue &attr_value, ParserOperator *op, const uint32_t io, const string &type) { | |||||
| const domi::tensorflow::AttrValue &attr_value, ParserOperator *op, const uint32_t io, const string &type) { | |||||
| GE_CHECK_NOTNULL(op); | GE_CHECK_NOTNULL(op); | ||||
| if (!attr_value.has_list()) { | if (!attr_value.has_list()) { | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| @@ -191,9 +191,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Tr | |||||
| // The shape infered by fusedbatchnormgrad and mean calling tensorflow is not accurate. | // The shape infered by fusedbatchnormgrad and mean calling tensorflow is not accurate. | ||||
| // Here, special treatment is given to the two operators. | // Here, special treatment is given to the two operators. | ||||
| // Adjust shape to fit resnet50 network only. | // Adjust shape to fit resnet50 network only. | ||||
| GE_IF_BOOL_EXEC((type == ge::parser::FUSEDBATCHNORMGRAD) && (tmp_dim == 0), ge_desc.SetShape(ge::GeShape()); | |||||
| break;); | |||||
| GE_IF_BOOL_EXEC((type == ge::parser::MEAN) && (tmp_dim == 0), vector<int64_t> data_dim = {tmp_dim}; | |||||
| GE_IF_BOOL_EXEC((type == ge::FUSEDBATCHNORMGRAD) && (tmp_dim == 0), ge_desc.SetShape(ge::GeShape()); break;); | |||||
| GE_IF_BOOL_EXEC((type == ge::MEAN) && (tmp_dim == 0), vector<int64_t> data_dim = {tmp_dim}; | |||||
| ge_desc.SetShape(ge::GeShape(data_dim)); break;); | ge_desc.SetShape(ge::GeShape(data_dim)); break;); | ||||
| } | } | ||||
| ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum()); | ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum()); | ||||
| @@ -215,7 +214,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Tr | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TensorFlowUtil::AddNodeAttr( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TensorFlowUtil::AddNodeAttr( | ||||
| const string &attr_name, const domi::tensorflow::AttrValue &value, domi::tensorflow::NodeDef *node_def) { | |||||
| const string &attr_name, const domi::tensorflow::AttrValue &value, domi::tensorflow::NodeDef *node_def) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_INFO(node_def == nullptr, return, "input parameter is null."); | GE_CHK_BOOL_TRUE_EXEC_INFO(node_def == nullptr, return, "input parameter is null."); | ||||
| node_def->mutable_attr()->insert(AttrValueMap::value_type(attr_name, value)); | node_def->mutable_attr()->insert(AttrValueMap::value_type(attr_name, value)); | ||||
| } | } | ||||
| @@ -26,7 +26,7 @@ | |||||
| #include "external/graph/attr_value.h" | #include "external/graph/attr_value.h" | ||||
| #include "external/graph/graph.h" | #include "external/graph/graph.h" | ||||
| #include "external/graph/operator.h" | #include "external/graph/operator.h" | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "framework/common/types.h" | |||||
| #include "framework/omg/omg_inner_types.h" | #include "framework/omg/omg_inner_types.h" | ||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| #include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
| @@ -22,8 +22,6 @@ | |||||
| #include "parser/tensorflow/tensorflow_op_parser.h" | #include "parser/tensorflow/tensorflow_op_parser.h" | ||||
| #include "parser/tensorflow/tensorflow_parser_register.h" | #include "parser/tensorflow/tensorflow_parser_register.h" | ||||
| using namespace ge::parser; | |||||
| namespace ge { | namespace ge { | ||||
| Status ParseParams(const Message *op_src, VarIsInitializedOpOperator *op) { | Status ParseParams(const Message *op_src, VarIsInitializedOpOperator *op) { | ||||
| GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
| @@ -32,7 +32,6 @@ | |||||
| using domi::tensorflow::AttrValue; | using domi::tensorflow::AttrValue; | ||||
| using domi::tensorflow::NodeDef; | using domi::tensorflow::NodeDef; | ||||
| using domi::tensorflow::TensorShapeProto; | using domi::tensorflow::TensorShapeProto; | ||||
| using namespace ge::parser; | |||||
| namespace ge { | namespace ge { | ||||
| const std::string SERIALIZE_FORMAT = "serialize_format"; | const std::string SERIALIZE_FORMAT = "serialize_format"; | ||||